• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
19 import static java.util.Arrays.asList;
20 import static java.util.Collections.shuffle;
21 import static java.util.Collections.singletonList;
22 import static java.util.concurrent.Executors.newFixedThreadPool;
23 import static java.util.concurrent.TimeUnit.SECONDS;
24 
25 import com.google.common.annotations.VisibleForTesting;
26 import com.google.common.base.Joiner;
27 import com.google.common.base.Objects;
28 import com.google.common.base.Preconditions;
29 import com.google.common.base.Splitter;
30 import com.google.common.collect.Iterators;
31 import com.google.common.util.concurrent.Futures;
32 import com.google.common.util.concurrent.ListenableFuture;
33 import com.google.common.util.concurrent.ListeningExecutorService;
34 import com.google.common.util.concurrent.MoreExecutors;
35 import io.grpc.ManagedChannel;
36 import io.grpc.ManagedChannelBuilder;
37 import io.grpc.Server;
38 import io.grpc.ServerBuilder;
39 import io.grpc.Status;
40 import io.grpc.StatusException;
41 import io.grpc.netty.GrpcSslContexts;
42 import io.grpc.netty.NegotiationType;
43 import io.grpc.netty.NettyChannelBuilder;
44 import io.grpc.stub.StreamObserver;
45 import io.grpc.testing.TlsTesting;
46 import io.netty.handler.ssl.SslContext;
47 import java.io.IOException;
48 import java.net.InetAddress;
49 import java.net.InetSocketAddress;
50 import java.net.URI;
51 import java.net.URISyntaxException;
52 import java.net.UnknownHostException;
53 import java.util.ArrayList;
54 import java.util.Collections;
55 import java.util.Iterator;
56 import java.util.List;
57 import java.util.Locale;
58 import java.util.Map;
59 import java.util.concurrent.ConcurrentHashMap;
60 import java.util.logging.Level;
61 import java.util.logging.Logger;
62 
63 /**
64  * A stress test client following the
65  * <a href="https://github.com/grpc/grpc/blob/master/tools/run_tests/stress_test/STRESS_CLIENT_SPEC.md">
66  * specifications</a> of the gRPC stress testing framework.
67  */
68 public class StressTestClient {
69 
70   private static final Logger log = Logger.getLogger(StressTestClient.class.getName());
71 
72   /**
73    * The main application allowing this client to be launched from the command line.
74    */
main(String... args)75   public static void main(String... args) throws Exception {
76     final StressTestClient client = new StressTestClient();
77     client.parseArgs(args);
78 
79     // Attempt an orderly shutdown, if the JVM is shutdown via a signal.
80     Runtime.getRuntime().addShutdownHook(new Thread() {
81       @Override
82       public void run() {
83         client.shutdown();
84       }
85     });
86 
87     try {
88       client.startMetricsService();
89       client.runStressTest();
90       client.blockUntilStressTestComplete();
91     } catch (Exception e) {
92       log.log(Level.WARNING, "The stress test client encountered an error!", e);
93     } finally {
94       client.shutdown();
95     }
96   }
97 
98   private static final int WORKER_GRACE_PERIOD_SECS = 30;
99 
100   private List<InetSocketAddress> addresses =
101       singletonList(new InetSocketAddress("localhost", 8080));
102   private List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>();
103 
104   private String serverHostOverride;
105   private boolean useTls = false;
106   private boolean useTestCa = false;
107   private int durationSecs = -1;
108   private int channelsPerServer = 1;
109   private int stubsPerChannel = 1;
110   private int metricsPort = 8081;
111 
112   private Server metricsServer;
113   private final Map<String, Metrics.GaugeResponse> gauges =
114       new ConcurrentHashMap<>();
115 
116   private volatile boolean shutdown;
117 
118   /**
119    * List of futures that {@link #blockUntilStressTestComplete()} waits for.
120    */
121   private final List<ListenableFuture<?>> workerFutures =
122       new ArrayList<>();
123   private final List<ManagedChannel> channels = new ArrayList<>();
124   private ListeningExecutorService threadpool;
125 
126   @VisibleForTesting
parseArgs(String[] args)127   void parseArgs(String[] args) {
128     boolean usage = false;
129     String serverAddresses = "";
130     for (String arg : args) {
131       if (!arg.startsWith("--")) {
132         System.err.println("All arguments must start with '--': " + arg);
133         usage = true;
134         break;
135       }
136       String[] parts = arg.substring(2).split("=", 2);
137       String key = parts[0];
138       if ("help".equals(key)) {
139         usage = true;
140         break;
141       }
142       if (parts.length != 2) {
143         System.err.println("All arguments must be of the form --arg=value");
144         usage = true;
145         break;
146       }
147       String value = parts[1];
148       if ("server_addresses".equals(key)) {
149         // May need to apply server host overrides to the addresses, so delay processing
150         serverAddresses = value;
151       } else if ("server_host_override".equals(key)) {
152         serverHostOverride = value;
153       } else if ("use_tls".equals(key)) {
154         useTls = Boolean.parseBoolean(value);
155       } else if ("use_test_ca".equals(key)) {
156         useTestCa = Boolean.parseBoolean(value);
157       } else if ("test_cases".equals(key)) {
158         testCaseWeightPairs = parseTestCases(value);
159       } else if ("test_duration_secs".equals(key)) {
160         durationSecs = Integer.valueOf(value);
161       } else if ("num_channels_per_server".equals(key)) {
162         channelsPerServer = Integer.valueOf(value);
163       } else if ("num_stubs_per_channel".equals(key)) {
164         stubsPerChannel = Integer.valueOf(value);
165       } else if ("metrics_port".equals(key)) {
166         metricsPort = Integer.valueOf(value);
167       } else {
168         System.err.println("Unknown argument: " + key);
169         usage = true;
170         break;
171       }
172     }
173 
174     if (!usage && !serverAddresses.isEmpty()) {
175       addresses = parseServerAddresses(serverAddresses);
176       usage = addresses.isEmpty();
177     }
178 
179     if (usage) {
180       StressTestClient c = new StressTestClient();
181       System.err.println(
182           "Usage: [ARGS...]"
183               + "\n"
184               + "\n  --server_host_override=HOST    Claimed identification expected of server."
185               + "\n                                 Defaults to server host"
186               + "\n  --server_addresses=<name_1>:<port_1>,<name_2>:<port_2>...<name_N>:<port_N>"
187               + "\n    Default: " + serverAddressesToString(c.addresses)
188               + "\n  --test_cases=<testcase_1:w_1>,<testcase_2:w_2>...<testcase_n:w_n>"
189               + "\n    List of <testcase,weight> tuples. Weight is the relative frequency at which"
190               + " testcase is run."
191               + "\n    Valid Testcases:"
192               + validTestCasesHelpText()
193               + "\n  --use_tls=true|false           Whether to use TLS. Default: " + c.useTls
194               + "\n  --use_test_ca=true|false       Whether to trust our fake CA. Requires"
195               + " --use_tls=true"
196               + "\n                                 to have effect. Default: " + c.useTestCa
197               + "\n  --test_duration_secs=SECONDS   '-1' for no limit. Default: " + c.durationSecs
198               + "\n  --num_channels_per_server=INT  Number of connections to each server address."
199               + " Default: " + c.channelsPerServer
200               + "\n  --num_stubs_per_channel=INT    Default: " + c.stubsPerChannel
201               + "\n  --metrics_port=PORT            Listening port of the metrics server."
202               + " Default: " + c.metricsPort
203       );
204       System.exit(1);
205     }
206   }
207 
208   @VisibleForTesting
startMetricsService()209   void startMetricsService() throws IOException {
210     Preconditions.checkState(!shutdown, "client was shutdown.");
211 
212     metricsServer = ServerBuilder.forPort(metricsPort)
213         .addService(new MetricsServiceImpl())
214         .build()
215         .start();
216   }
217 
218   @VisibleForTesting
runStressTest()219   void runStressTest() throws Exception {
220     Preconditions.checkState(!shutdown, "client was shutdown.");
221     if (testCaseWeightPairs.isEmpty()) {
222       return;
223     }
224 
225     int numChannels = addresses.size() * channelsPerServer;
226     int numThreads = numChannels * stubsPerChannel;
227     threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads));
228     int serverIdx = -1;
229     for (InetSocketAddress address : addresses) {
230       serverIdx++;
231       for (int i = 0; i < channelsPerServer; i++) {
232         ManagedChannel channel = createChannel(address);
233         channels.add(channel);
234         for (int j = 0; j < stubsPerChannel; j++) {
235           String gaugeName = String.format(
236               Locale.US, "/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j);
237           Worker worker =
238               new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName);
239 
240           workerFutures.add(threadpool.submit(worker));
241         }
242       }
243     }
244   }
245 
246   @VisibleForTesting
blockUntilStressTestComplete()247   void blockUntilStressTestComplete() throws Exception {
248     Preconditions.checkState(!shutdown, "client was shutdown.");
249 
250     ListenableFuture<?> f = Futures.allAsList(workerFutures);
251     if (durationSecs == -1) {
252       // '-1' indicates that the stress test runs until terminated by the user.
253       f.get();
254     } else {
255       f.get(durationSecs + WORKER_GRACE_PERIOD_SECS, SECONDS);
256     }
257   }
258 
259   @VisibleForTesting
shutdown()260   void shutdown() {
261     if (shutdown) {
262       return;
263     }
264     shutdown = true;
265 
266     for (ManagedChannel ch : channels) {
267       try {
268         ch.shutdownNow();
269         ch.awaitTermination(1, SECONDS);
270       } catch (Throwable t) {
271         log.log(Level.WARNING, "Error shutting down channel!", t);
272       }
273     }
274 
275     try {
276       metricsServer.shutdownNow();
277     } catch (Throwable t) {
278       log.log(Level.WARNING, "Error shutting down metrics service!", t);
279     }
280 
281     try {
282       if (threadpool != null) {
283         threadpool.shutdownNow();
284       }
285     } catch (Throwable t) {
286       log.log(Level.WARNING, "Error shutting down threadpool.", t);
287     }
288   }
289 
290   @VisibleForTesting
getMetricServerPort()291   int getMetricServerPort() {
292     return metricsServer.getPort();
293   }
294 
parseServerAddresses(String addressesStr)295   private List<InetSocketAddress> parseServerAddresses(String addressesStr) {
296     List<InetSocketAddress> addresses = new ArrayList<>();
297 
298     for (List<String> namePort : parseCommaSeparatedTuples(addressesStr)) {
299       InetAddress address;
300       String name = namePort.get(0);
301       int port = Integer.valueOf(namePort.get(1));
302       try {
303         address = InetAddress.getByName(name);
304         if (serverHostOverride != null) {
305           // Force the hostname to match the cert the server uses.
306           address = InetAddress.getByAddress(serverHostOverride, address.getAddress());
307         }
308       } catch (UnknownHostException ex) {
309         throw new RuntimeException(ex);
310       }
311       addresses.add(new InetSocketAddress(address, port));
312     }
313 
314     return addresses;
315   }
316 
parseTestCases(String testCasesStr)317   private static List<TestCaseWeightPair> parseTestCases(String testCasesStr) {
318     List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>();
319 
320     for (List<String> nameWeight : parseCommaSeparatedTuples(testCasesStr)) {
321       TestCases testCase = TestCases.fromString(nameWeight.get(0));
322       int weight = Integer.valueOf(nameWeight.get(1));
323       testCaseWeightPairs.add(new TestCaseWeightPair(testCase, weight));
324     }
325 
326     return testCaseWeightPairs;
327   }
328 
parseCommaSeparatedTuples(String str)329   private static List<List<String>> parseCommaSeparatedTuples(String str) {
330     List<List<String>> tuples = new ArrayList<>();
331     for (String tupleStr : Splitter.on(',').split(str)) {
332       int splitIdx = tupleStr.lastIndexOf(':');
333       if (splitIdx == -1) {
334         throw new IllegalArgumentException("Illegal tuple format: '" + tupleStr + "'");
335       }
336       String part0 = tupleStr.substring(0, splitIdx);
337       String part1 = tupleStr.substring(splitIdx + 1);
338       tuples.add(asList(part0, part1));
339     }
340     return tuples;
341   }
342 
createChannel(InetSocketAddress address)343   private ManagedChannel createChannel(InetSocketAddress address) {
344     SslContext sslContext = null;
345     if (useTestCa) {
346       try {
347         sslContext = GrpcSslContexts.forClient().trustManager(
348             TlsTesting.loadCert("ca.pem")).build();
349       } catch (Exception ex) {
350         throw new RuntimeException(ex);
351       }
352     }
353     return NettyChannelBuilder.forAddress(address)
354         .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT)
355         .sslContext(sslContext)
356         .build();
357   }
358 
serverAddressesToString(List<InetSocketAddress> addresses)359   private static String serverAddressesToString(List<InetSocketAddress> addresses) {
360     List<String> tmp = new ArrayList<>();
361     for (InetSocketAddress address : addresses) {
362       URI uri;
363       try {
364         uri = new URI(null, null, address.getHostName(), address.getPort(), null, null, null);
365       } catch (URISyntaxException e) {
366         throw new RuntimeException(e);
367       }
368       tmp.add(uri.getAuthority());
369     }
370     return Joiner.on(',').join(tmp);
371   }
372 
validTestCasesHelpText()373   private static String validTestCasesHelpText() {
374     StringBuilder builder = new StringBuilder();
375     for (TestCases testCase : TestCases.values()) {
376       String strTestcase = testCase.name().toLowerCase();
377       builder.append("\n      ")
378           .append(strTestcase)
379           .append(": ")
380           .append(testCase.description());
381     }
382     return builder.toString();
383   }
384 
385   /**
386    * A stress test worker. Every stub has its own stress test worker.
387    */
388   private class Worker implements Runnable {
389 
390     // Interval at which the QPS stats of metrics service are updated.
391     private static final long METRICS_COLLECTION_INTERVAL_SECS = 5;
392 
393     private final ManagedChannel channel;
394     private final List<TestCaseWeightPair> testCaseWeightPairs;
395     private final Integer durationSec;
396     private final String gaugeName;
397 
Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, int durationSec, String gaugeName)398     Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs,
399         int durationSec, String gaugeName) {
400       Preconditions.checkArgument(durationSec >= -1, "durationSec must be gte -1.");
401       this.channel = Preconditions.checkNotNull(channel, "channel");
402       this.testCaseWeightPairs =
403           Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs");
404       this.durationSec = durationSec == -1 ? null : durationSec;
405       this.gaugeName = Preconditions.checkNotNull(gaugeName, "gaugeName");
406     }
407 
408     @Override
run()409     public void run() {
410       // Simplify debugging if the worker crashes / never terminates.
411       Thread.currentThread().setName(gaugeName);
412 
413       Tester tester = new Tester();
414       tester.setUp();
415       WeightedTestCaseSelector testCaseSelector = new WeightedTestCaseSelector(testCaseWeightPairs);
416       Long endTime = durationSec == null ? null : System.nanoTime() + SECONDS.toNanos(durationSecs);
417       long lastMetricsCollectionTime = initLastMetricsCollectionTime();
418       // Number of interop testcases run since the last time metrics have been updated.
419       long testCasesSinceLastMetricsCollection = 0;
420 
421       while (!Thread.currentThread().isInterrupted() && !shutdown
422           && (endTime == null || endTime - System.nanoTime() > 0)) {
423         try {
424           runTestCase(tester, testCaseSelector.nextTestCase());
425         } catch (Exception e) {
426           throw new RuntimeException(e);
427         }
428 
429         testCasesSinceLastMetricsCollection++;
430 
431         double durationSecs = computeDurationSecs(lastMetricsCollectionTime);
432         if (durationSecs >= METRICS_COLLECTION_INTERVAL_SECS) {
433           long qps = (long) Math.ceil(testCasesSinceLastMetricsCollection / durationSecs);
434 
435           Metrics.GaugeResponse gauge = Metrics.GaugeResponse
436               .newBuilder()
437               .setName(gaugeName)
438               .setLongValue(qps)
439               .build();
440 
441           gauges.put(gaugeName, gauge);
442 
443           lastMetricsCollectionTime = System.nanoTime();
444           testCasesSinceLastMetricsCollection = 0;
445         }
446       }
447     }
448 
initLastMetricsCollectionTime()449     private long initLastMetricsCollectionTime() {
450       return System.nanoTime() - SECONDS.toNanos(METRICS_COLLECTION_INTERVAL_SECS);
451     }
452 
computeDurationSecs(long lastMetricsCollectionTime)453     private double computeDurationSecs(long lastMetricsCollectionTime) {
454       return (System.nanoTime() - lastMetricsCollectionTime) / 1000000000.0;
455     }
456 
runTestCase(Tester tester, TestCases testCase)457     private void runTestCase(Tester tester, TestCases testCase) throws Exception {
458       // TODO(buchgr): Implement tests requiring auth, once C++ supports it.
459       switch (testCase) {
460         case EMPTY_UNARY:
461           tester.emptyUnary();
462           break;
463 
464         case LARGE_UNARY:
465           tester.largeUnary();
466           break;
467 
468         case CLIENT_STREAMING:
469           tester.clientStreaming();
470           break;
471 
472         case SERVER_STREAMING:
473           tester.serverStreaming();
474           break;
475 
476         case PING_PONG:
477           tester.pingPong();
478           break;
479 
480         case EMPTY_STREAM:
481           tester.emptyStream();
482           break;
483 
484         case UNIMPLEMENTED_METHOD: {
485           tester.unimplementedMethod();
486           break;
487         }
488 
489         case UNIMPLEMENTED_SERVICE: {
490           tester.unimplementedService();
491           break;
492         }
493 
494         case CANCEL_AFTER_BEGIN: {
495           tester.cancelAfterBegin();
496           break;
497         }
498 
499         case CANCEL_AFTER_FIRST_RESPONSE: {
500           tester.cancelAfterFirstResponse();
501           break;
502         }
503 
504         case TIMEOUT_ON_SLEEPING_SERVER: {
505           tester.timeoutOnSleepingServer();
506           break;
507         }
508 
509         default:
510           throw new IllegalArgumentException("Unknown test case: " + testCase);
511       }
512     }
513 
514     class Tester extends AbstractInteropTest {
515       @Override
createChannel()516       protected ManagedChannel createChannel() {
517         return Worker.this.channel;
518       }
519 
520       @Override
createChannelBuilder()521       protected ManagedChannelBuilder<?> createChannelBuilder() {
522         throw new UnsupportedOperationException();
523       }
524 
525       @Override
operationTimeoutMillis()526       protected int operationTimeoutMillis() {
527         // Don't enforce a timeout when using the interop tests for the stress test client.
528         // Fixes https://github.com/grpc/grpc-java/issues/1812
529         return Integer.MAX_VALUE;
530       }
531 
532       @Override
metricsExpected()533       protected boolean metricsExpected() {
534         // TODO(zhangkun83): we may want to enable the real google Instrumentation implementation in
535         // stress tests.
536         return false;
537       }
538     }
539 
540     class WeightedTestCaseSelector {
541       /**
542        * Randomly shuffled and cyclic sequence that contains each testcase proportionally
543        * to its weight.
544        */
545       final Iterator<TestCases> testCases;
546 
WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs)547       WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs) {
548         Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs");
549         Preconditions.checkArgument(testCaseWeightPairs.size() > 0);
550 
551         List<TestCases> testCases = new ArrayList<>();
552         for (TestCaseWeightPair testCaseWeightPair : testCaseWeightPairs) {
553           for (int i = 0; i < testCaseWeightPair.weight; i++) {
554             testCases.add(testCaseWeightPair.testCase);
555           }
556         }
557 
558         shuffle(testCases);
559 
560         this.testCases = Iterators.cycle(testCases);
561       }
562 
nextTestCase()563       TestCases nextTestCase() {
564         return testCases.next();
565       }
566     }
567   }
568 
569   /**
570    * Service that exports the QPS metrics of the stress test.
571    */
572   private class MetricsServiceImpl extends MetricsServiceGrpc.MetricsServiceImplBase {
573 
574     @Override
getAllGauges(Metrics.EmptyMessage request, StreamObserver<Metrics.GaugeResponse> responseObserver)575     public void getAllGauges(Metrics.EmptyMessage request,
576         StreamObserver<Metrics.GaugeResponse> responseObserver) {
577       for (Metrics.GaugeResponse gauge : gauges.values()) {
578         responseObserver.onNext(gauge);
579       }
580       responseObserver.onCompleted();
581     }
582 
583     @Override
getGauge(Metrics.GaugeRequest request, StreamObserver<Metrics.GaugeResponse> responseObserver)584     public void getGauge(Metrics.GaugeRequest request,
585         StreamObserver<Metrics.GaugeResponse> responseObserver) {
586       String gaugeName = request.getName();
587       Metrics.GaugeResponse gauge = gauges.get(gaugeName);
588       if (gauge != null) {
589         responseObserver.onNext(gauge);
590         responseObserver.onCompleted();
591       } else {
592         responseObserver.onError(new StatusException(Status.NOT_FOUND));
593       }
594     }
595   }
596 
597   @VisibleForTesting
598   static class TestCaseWeightPair {
599     final TestCases testCase;
600     final int weight;
601 
TestCaseWeightPair(TestCases testCase, int weight)602     TestCaseWeightPair(TestCases testCase, int weight) {
603       Preconditions.checkArgument(weight >= 0, "weight must be positive.");
604       this.testCase = Preconditions.checkNotNull(testCase, "testCase");
605       this.weight = weight;
606     }
607 
608     @Override
equals(Object other)609     public boolean equals(Object other) {
610       if (!(other instanceof TestCaseWeightPair)) {
611         return false;
612       }
613       TestCaseWeightPair that = (TestCaseWeightPair) other;
614       return testCase.equals(that.testCase) && weight == that.weight;
615     }
616 
617     @Override
hashCode()618     public int hashCode() {
619       return Objects.hashCode(testCase, weight);
620     }
621   }
622 
623   @VisibleForTesting
addresses()624   List<InetSocketAddress> addresses() {
625     return Collections.unmodifiableList(addresses);
626   }
627 
628   @VisibleForTesting
serverHostOverride()629   String serverHostOverride() {
630     return serverHostOverride;
631   }
632 
633   @VisibleForTesting
useTls()634   boolean useTls() {
635     return useTls;
636   }
637 
638   @VisibleForTesting
useTestCa()639   boolean useTestCa() {
640     return useTestCa;
641   }
642 
643   @VisibleForTesting
testCaseWeightPairs()644   List<TestCaseWeightPair> testCaseWeightPairs() {
645     return testCaseWeightPairs;
646   }
647 
648   @VisibleForTesting
durationSecs()649   int durationSecs() {
650     return durationSecs;
651   }
652 
653   @VisibleForTesting
channelsPerServer()654   int channelsPerServer() {
655     return channelsPerServer;
656   }
657 
658   @VisibleForTesting
stubsPerChannel()659   int stubsPerChannel() {
660     return stubsPerChannel;
661   }
662 
663   @VisibleForTesting
metricsPort()664   int metricsPort() {
665     return metricsPort;
666   }
667 }
668