• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2015 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
19 import static com.google.common.truth.Truth.assertAbout;
20 import static io.grpc.testing.DeadlineSubject.deadline;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertNotSame;
23 import static org.junit.Assert.assertTrue;
24 import static org.junit.Assert.fail;
25 
26 import com.google.common.util.concurrent.SettableFuture;
27 import io.grpc.Context;
28 import io.grpc.Context.CancellableContext;
29 import io.grpc.Deadline;
30 import io.grpc.ManagedChannel;
31 import io.grpc.Metadata;
32 import io.grpc.Server;
33 import io.grpc.ServerCall;
34 import io.grpc.ServerCallHandler;
35 import io.grpc.ServerInterceptor;
36 import io.grpc.ServerInterceptors;
37 import io.grpc.Status;
38 import io.grpc.StatusRuntimeException;
39 import io.grpc.inprocess.InProcessChannelBuilder;
40 import io.grpc.inprocess.InProcessServerBuilder;
41 import io.grpc.stub.ServerCallStreamObserver;
42 import io.grpc.stub.StreamObserver;
43 import io.grpc.testing.integration.Messages.SimpleRequest;
44 import io.grpc.testing.integration.Messages.SimpleResponse;
45 import java.io.IOException;
46 import java.util.concurrent.CountDownLatch;
47 import java.util.concurrent.ExecutionException;
48 import java.util.concurrent.ExecutorService;
49 import java.util.concurrent.Executors;
50 import java.util.concurrent.Future;
51 import java.util.concurrent.TimeUnit;
52 import java.util.concurrent.atomic.AtomicInteger;
53 import org.junit.After;
54 import org.junit.Before;
55 import org.junit.Test;
56 import org.junit.runner.RunWith;
57 import org.junit.runners.JUnit4;
58 import org.mockito.Mock;
59 import org.mockito.MockitoAnnotations;
60 
61 /**
62  * Integration test for various forms of cancellation and deadline propagation.
63  */
64 @RunWith(JUnit4.class)
65 public class CascadingTest {
66 
67   @Mock
68   TestServiceGrpc.TestServiceImplBase service;
69   private ManagedChannel channel;
70   private Server server;
71   private CountDownLatch observedCancellations;
72   private CountDownLatch receivedCancellations;
73   private TestServiceGrpc.TestServiceBlockingStub blockingStub;
74   private TestServiceGrpc.TestServiceStub asyncStub;
75   private TestServiceGrpc.TestServiceFutureStub futureStub;
76   private ExecutorService otherWork;
77 
78   @Before
setUp()79   public void setUp() throws Exception {
80     MockitoAnnotations.initMocks(this);
81     // Use a cached thread pool as we need a thread for each blocked call
82     otherWork = Executors.newCachedThreadPool();
83     channel = InProcessChannelBuilder.forName("channel").executor(otherWork).build();
84     blockingStub = TestServiceGrpc.newBlockingStub(channel);
85     asyncStub = TestServiceGrpc.newStub(channel);
86     futureStub = TestServiceGrpc.newFutureStub(channel);
87   }
88 
89   @After
tearDown()90   public void tearDown() {
91     channel.shutdownNow();
92     server.shutdownNow();
93     otherWork.shutdownNow();
94   }
95 
96   /**
97    * Test {@link Context} cancellation propagates from the first node in the call chain all the way
98    * to the last.
99    */
100   @Test
testCascadingCancellationViaOuterContextCancellation()101   public void testCascadingCancellationViaOuterContextCancellation() throws Exception {
102     observedCancellations = new CountDownLatch(2);
103     receivedCancellations = new CountDownLatch(3);
104     Future<?> chainReady = startChainingServer(3);
105     CancellableContext context = Context.current().withCancellation();
106     Future<SimpleResponse> future;
107     Context prevContext = context.attach();
108     try {
109       future = futureStub.unaryCall(SimpleRequest.getDefaultInstance());
110     } finally {
111       context.detach(prevContext);
112     }
113     chainReady.get(5, TimeUnit.SECONDS);
114 
115     context.cancel(null);
116     try {
117       future.get(5, TimeUnit.SECONDS);
118       fail("Expected cancellation");
119     } catch (ExecutionException ex) {
120       Status status = Status.fromThrowable(ex);
121       assertEquals(Status.Code.CANCELLED, status.getCode());
122 
123       // Should have observed 2 cancellations responses from downstream servers
124       if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
125         fail("Expected number of cancellations not observed by clients");
126       }
127       if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
128         fail("Expected number of cancellations to be received by servers not observed");
129       }
130     }
131   }
132 
133   /**
134    * Test that cancellation via call cancellation propagates down the call.
135    */
136   @Test
testCascadingCancellationViaRpcCancel()137   public void testCascadingCancellationViaRpcCancel() throws Exception {
138     observedCancellations = new CountDownLatch(2);
139     receivedCancellations = new CountDownLatch(3);
140     Future<?> chainReady = startChainingServer(3);
141     Future<SimpleResponse> future = futureStub.unaryCall(SimpleRequest.getDefaultInstance());
142     chainReady.get(5, TimeUnit.SECONDS);
143 
144     future.cancel(true);
145     assertTrue(future.isCancelled());
146     if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
147       fail("Expected number of cancellations not observed by clients");
148     }
149     if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
150       fail("Expected number of cancellations to be received by servers not observed");
151     }
152   }
153 
154   /**
155    * Test that when RPC cancellation propagates up a call chain, the cancellation of the parent
156    * RPC triggers cancellation of all of its children.
157    */
158   @Test
testCascadingCancellationViaLeafFailure()159   public void testCascadingCancellationViaLeafFailure() throws Exception {
160     // All nodes (15) except one edge of the tree (4) will be cancelled.
161     observedCancellations = new CountDownLatch(11);
162     receivedCancellations = new CountDownLatch(11);
163     startCallTreeServer(3);
164     try {
165       // Use response size limit to control tree nodeCount.
166       blockingStub.unaryCall(Messages.SimpleRequest.newBuilder().setResponseSize(3).build());
167       fail("Expected abort");
168     } catch (StatusRuntimeException sre) {
169       // Wait for the workers to finish
170       Status status = sre.getStatus();
171       // Outermost caller observes ABORTED propagating up from the failing leaf,
172       // The descendant RPCs are cancelled so they receive CANCELLED.
173       assertEquals(Status.Code.ABORTED, status.getCode());
174 
175       if (!observedCancellations.await(5, TimeUnit.SECONDS)) {
176         fail("Expected number of cancellations not observed by clients");
177       }
178       if (!receivedCancellations.await(5, TimeUnit.SECONDS)) {
179         fail("Expected number of cancellations to be received by servers not observed");
180       }
181     }
182   }
183 
184   @Test
testDeadlinePropagation()185   public void testDeadlinePropagation() throws Exception {
186     final AtomicInteger recursionDepthRemaining = new AtomicInteger(3);
187     final SettableFuture<Deadline> finalDeadline = SettableFuture.create();
188     class DeadlineSaver extends TestServiceGrpc.TestServiceImplBase {
189       @Override
190       public void unaryCall(final SimpleRequest request,
191           final StreamObserver<SimpleResponse> responseObserver) {
192         Context.currentContextExecutor(otherWork).execute(new Runnable() {
193           @Override
194           public void run() {
195             try {
196               if (recursionDepthRemaining.decrementAndGet() == 0) {
197                 finalDeadline.set(Context.current().getDeadline());
198                 responseObserver.onNext(SimpleResponse.getDefaultInstance());
199               } else {
200                 responseObserver.onNext(blockingStub.unaryCall(request));
201               }
202               responseObserver.onCompleted();
203             } catch (Exception ex) {
204               responseObserver.onError(ex);
205             }
206           }
207         });
208       }
209     }
210 
211     server = InProcessServerBuilder.forName("channel").executor(otherWork)
212         .addService(new DeadlineSaver())
213         .build().start();
214 
215     Deadline initialDeadline = Deadline.after(1, TimeUnit.MINUTES);
216     blockingStub.withDeadline(initialDeadline).unaryCall(SimpleRequest.getDefaultInstance());
217     assertNotSame(initialDeadline, finalDeadline);
218     // Since deadline is re-calculated at each hop, some variance is acceptable and expected.
219     assertAbout(deadline())
220         .that(finalDeadline.get()).isWithin(1, TimeUnit.SECONDS).of(initialDeadline);
221   }
222 
223   /**
224    * Create a chain of client to server calls which can be cancelled top down.
225    *
226    * @return a Future that completes when call chain is created
227    */
startChainingServer(final int depthThreshold)228   private Future<?> startChainingServer(final int depthThreshold) throws IOException {
229     final AtomicInteger serversReady = new AtomicInteger();
230     final SettableFuture<Void> chainReady = SettableFuture.create();
231     class ChainingService extends TestServiceGrpc.TestServiceImplBase {
232       @Override
233       public void unaryCall(final SimpleRequest request,
234           final StreamObserver<SimpleResponse> responseObserver) {
235         ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(new Runnable() {
236           @Override
237           public void run() {
238             receivedCancellations.countDown();
239           }
240         });
241         if (serversReady.incrementAndGet() == depthThreshold) {
242           // Stop recursion
243           chainReady.set(null);
244           return;
245         }
246 
247         Context.currentContextExecutor(otherWork).execute(new Runnable() {
248           @Override
249           public void run() {
250             try {
251               blockingStub.unaryCall(request);
252             } catch (StatusRuntimeException e) {
253               Status status = e.getStatus();
254               if (status.getCode() == Status.Code.CANCELLED) {
255                 observedCancellations.countDown();
256               } else {
257                 responseObserver.onError(e);
258               }
259             }
260           }
261         });
262       }
263     }
264 
265     server = InProcessServerBuilder.forName("channel").executor(otherWork)
266         .addService(new ChainingService())
267         .build().start();
268     return chainReady;
269   }
270 
271   /**
272    * Create a tree of client to server calls where each received call on the server
273    * fans out to two downstream calls. Uses SimpleRequest.response_size to limit the nodeCount
274    * of the tree. One of the leaves will ABORT to trigger cancellation back up to tree.
275    */
startCallTreeServer(int depthThreshold)276   private void startCallTreeServer(int depthThreshold) throws IOException {
277     final AtomicInteger nodeCount = new AtomicInteger((2 << depthThreshold) - 1);
278     server = InProcessServerBuilder.forName("channel").executor(otherWork).addService(
279         ServerInterceptors.intercept(service,
280             new ServerInterceptor() {
281               @Override
282               public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
283                   final ServerCall<ReqT, RespT> call,
284                   Metadata headers,
285                   ServerCallHandler<ReqT, RespT> next) {
286                 // Respond with the headers but nothing else.
287                 call.sendHeaders(new Metadata());
288                 call.request(1);
289                 return new ServerCall.Listener<ReqT>() {
290                   @Override
291                   public void onMessage(final ReqT message) {
292                     Messages.SimpleRequest req = (Messages.SimpleRequest) message;
293                     if (nodeCount.decrementAndGet() == 0) {
294                       // we are in the final leaf node so trigger an ABORT upwards
295                       Context.currentContextExecutor(otherWork).execute(new Runnable() {
296                         @Override
297                         public void run() {
298                           call.close(Status.ABORTED, new Metadata());
299                         }
300                       });
301                     } else if (req.getResponseSize() != 0) {
302                       // We are in a non leaf node so fire off two requests
303                       req = req.toBuilder().setResponseSize(req.getResponseSize() - 1).build();
304                       for (int i = 0; i < 2; i++) {
305                         asyncStub.unaryCall(req,
306                             new StreamObserver<Messages.SimpleResponse>() {
307                               @Override
308                               public void onNext(Messages.SimpleResponse value) {
309                               }
310 
311                               @Override
312                               public void onError(Throwable t) {
313                                 Status status = Status.fromThrowable(t);
314                                 if (status.getCode() == Status.Code.CANCELLED) {
315                                   observedCancellations.countDown();
316                                 }
317                                 // Propagate closure upwards.
318                                 try {
319                                   call.close(status, new Metadata());
320                                 } catch (IllegalStateException t2) {
321                                   // Ignore error if already closed.
322                                 }
323                               }
324 
325                               @Override
326                               public void onCompleted() {
327                               }
328                             });
329                       }
330                     }
331                   }
332 
333                   @Override
334                   public void onCancel() {
335                     receivedCancellations.countDown();
336                   }
337                 };
338               }
339             })
340     ).build();
341     server.start();
342   }
343 }
344