1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.testing.integration; 18 19 import static java.util.Arrays.asList; 20 import static java.util.Collections.shuffle; 21 import static java.util.Collections.singletonList; 22 import static java.util.concurrent.Executors.newFixedThreadPool; 23 import static java.util.concurrent.TimeUnit.SECONDS; 24 25 import com.google.common.annotations.VisibleForTesting; 26 import com.google.common.base.Joiner; 27 import com.google.common.base.Objects; 28 import com.google.common.base.Preconditions; 29 import com.google.common.base.Splitter; 30 import com.google.common.collect.Iterators; 31 import com.google.common.util.concurrent.Futures; 32 import com.google.common.util.concurrent.ListenableFuture; 33 import com.google.common.util.concurrent.ListeningExecutorService; 34 import com.google.common.util.concurrent.MoreExecutors; 35 import io.grpc.ManagedChannel; 36 import io.grpc.ManagedChannelBuilder; 37 import io.grpc.Server; 38 import io.grpc.ServerBuilder; 39 import io.grpc.Status; 40 import io.grpc.StatusException; 41 import io.grpc.netty.GrpcSslContexts; 42 import io.grpc.netty.NegotiationType; 43 import io.grpc.netty.NettyChannelBuilder; 44 import io.grpc.stub.StreamObserver; 45 import io.grpc.testing.TlsTesting; 46 import io.netty.handler.ssl.SslContext; 47 import java.io.IOException; 48 import java.net.InetAddress; 49 import java.net.InetSocketAddress; 50 import java.net.URI; 51 import java.net.URISyntaxException; 52 import java.net.UnknownHostException; 53 import java.util.ArrayList; 54 import java.util.Collections; 55 import java.util.Iterator; 56 import java.util.List; 57 import java.util.Locale; 58 import java.util.Map; 59 import java.util.concurrent.ConcurrentHashMap; 60 import java.util.logging.Level; 61 import java.util.logging.Logger; 62 63 /** 64 * A stress test client following the 65 * <a href="https://github.com/grpc/grpc/blob/master/tools/run_tests/stress_test/STRESS_CLIENT_SPEC.md"> 66 * specifications</a> of the gRPC stress testing framework. 67 */ 68 public class StressTestClient { 69 70 private static final Logger log = Logger.getLogger(StressTestClient.class.getName()); 71 72 /** 73 * The main application allowing this client to be launched from the command line. 74 */ main(String... args)75 public static void main(String... args) throws Exception { 76 final StressTestClient client = new StressTestClient(); 77 client.parseArgs(args); 78 79 // Attempt an orderly shutdown, if the JVM is shutdown via a signal. 80 Runtime.getRuntime().addShutdownHook(new Thread() { 81 @Override 82 public void run() { 83 client.shutdown(); 84 } 85 }); 86 87 try { 88 client.startMetricsService(); 89 client.runStressTest(); 90 client.blockUntilStressTestComplete(); 91 } catch (Exception e) { 92 log.log(Level.WARNING, "The stress test client encountered an error!", e); 93 } finally { 94 client.shutdown(); 95 } 96 } 97 98 private static final int WORKER_GRACE_PERIOD_SECS = 30; 99 100 private List<InetSocketAddress> addresses = 101 singletonList(new InetSocketAddress("localhost", 8080)); 102 private List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>(); 103 104 private String serverHostOverride; 105 private boolean useTls = false; 106 private boolean useTestCa = false; 107 private int durationSecs = -1; 108 private int channelsPerServer = 1; 109 private int stubsPerChannel = 1; 110 private int metricsPort = 8081; 111 112 private Server metricsServer; 113 private final Map<String, Metrics.GaugeResponse> gauges = 114 new ConcurrentHashMap<>(); 115 116 private volatile boolean shutdown; 117 118 /** 119 * List of futures that {@link #blockUntilStressTestComplete()} waits for. 120 */ 121 private final List<ListenableFuture<?>> workerFutures = 122 new ArrayList<>(); 123 private final List<ManagedChannel> channels = new ArrayList<>(); 124 private ListeningExecutorService threadpool; 125 126 @VisibleForTesting parseArgs(String[] args)127 void parseArgs(String[] args) { 128 boolean usage = false; 129 String serverAddresses = ""; 130 for (String arg : args) { 131 if (!arg.startsWith("--")) { 132 System.err.println("All arguments must start with '--': " + arg); 133 usage = true; 134 break; 135 } 136 String[] parts = arg.substring(2).split("=", 2); 137 String key = parts[0]; 138 if ("help".equals(key)) { 139 usage = true; 140 break; 141 } 142 if (parts.length != 2) { 143 System.err.println("All arguments must be of the form --arg=value"); 144 usage = true; 145 break; 146 } 147 String value = parts[1]; 148 if ("server_addresses".equals(key)) { 149 // May need to apply server host overrides to the addresses, so delay processing 150 serverAddresses = value; 151 } else if ("server_host_override".equals(key)) { 152 serverHostOverride = value; 153 } else if ("use_tls".equals(key)) { 154 useTls = Boolean.parseBoolean(value); 155 } else if ("use_test_ca".equals(key)) { 156 useTestCa = Boolean.parseBoolean(value); 157 } else if ("test_cases".equals(key)) { 158 testCaseWeightPairs = parseTestCases(value); 159 } else if ("test_duration_secs".equals(key)) { 160 durationSecs = Integer.valueOf(value); 161 } else if ("num_channels_per_server".equals(key)) { 162 channelsPerServer = Integer.valueOf(value); 163 } else if ("num_stubs_per_channel".equals(key)) { 164 stubsPerChannel = Integer.valueOf(value); 165 } else if ("metrics_port".equals(key)) { 166 metricsPort = Integer.valueOf(value); 167 } else { 168 System.err.println("Unknown argument: " + key); 169 usage = true; 170 break; 171 } 172 } 173 174 if (!usage && !serverAddresses.isEmpty()) { 175 addresses = parseServerAddresses(serverAddresses); 176 usage = addresses.isEmpty(); 177 } 178 179 if (usage) { 180 StressTestClient c = new StressTestClient(); 181 System.err.println( 182 "Usage: [ARGS...]" 183 + "\n" 184 + "\n --server_host_override=HOST Claimed identification expected of server." 185 + "\n Defaults to server host" 186 + "\n --server_addresses=<name_1>:<port_1>,<name_2>:<port_2>...<name_N>:<port_N>" 187 + "\n Default: " + serverAddressesToString(c.addresses) 188 + "\n --test_cases=<testcase_1:w_1>,<testcase_2:w_2>...<testcase_n:w_n>" 189 + "\n List of <testcase,weight> tuples. Weight is the relative frequency at which" 190 + " testcase is run." 191 + "\n Valid Testcases:" 192 + validTestCasesHelpText() 193 + "\n --use_tls=true|false Whether to use TLS. Default: " + c.useTls 194 + "\n --use_test_ca=true|false Whether to trust our fake CA. Requires" 195 + " --use_tls=true" 196 + "\n to have effect. Default: " + c.useTestCa 197 + "\n --test_duration_secs=SECONDS '-1' for no limit. Default: " + c.durationSecs 198 + "\n --num_channels_per_server=INT Number of connections to each server address." 199 + " Default: " + c.channelsPerServer 200 + "\n --num_stubs_per_channel=INT Default: " + c.stubsPerChannel 201 + "\n --metrics_port=PORT Listening port of the metrics server." 202 + " Default: " + c.metricsPort 203 ); 204 System.exit(1); 205 } 206 } 207 208 @VisibleForTesting startMetricsService()209 void startMetricsService() throws IOException { 210 Preconditions.checkState(!shutdown, "client was shutdown."); 211 212 metricsServer = ServerBuilder.forPort(metricsPort) 213 .addService(new MetricsServiceImpl()) 214 .build() 215 .start(); 216 } 217 218 @VisibleForTesting runStressTest()219 void runStressTest() throws Exception { 220 Preconditions.checkState(!shutdown, "client was shutdown."); 221 if (testCaseWeightPairs.isEmpty()) { 222 return; 223 } 224 225 int numChannels = addresses.size() * channelsPerServer; 226 int numThreads = numChannels * stubsPerChannel; 227 threadpool = MoreExecutors.listeningDecorator(newFixedThreadPool(numThreads)); 228 int serverIdx = -1; 229 for (InetSocketAddress address : addresses) { 230 serverIdx++; 231 for (int i = 0; i < channelsPerServer; i++) { 232 ManagedChannel channel = createChannel(address); 233 channels.add(channel); 234 for (int j = 0; j < stubsPerChannel; j++) { 235 String gaugeName = String.format( 236 Locale.US, "/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx, i, j); 237 Worker worker = 238 new Worker(channel, testCaseWeightPairs, durationSecs, gaugeName); 239 240 workerFutures.add(threadpool.submit(worker)); 241 } 242 } 243 } 244 } 245 246 @VisibleForTesting blockUntilStressTestComplete()247 void blockUntilStressTestComplete() throws Exception { 248 Preconditions.checkState(!shutdown, "client was shutdown."); 249 250 ListenableFuture<?> f = Futures.allAsList(workerFutures); 251 if (durationSecs == -1) { 252 // '-1' indicates that the stress test runs until terminated by the user. 253 f.get(); 254 } else { 255 f.get(durationSecs + WORKER_GRACE_PERIOD_SECS, SECONDS); 256 } 257 } 258 259 @VisibleForTesting shutdown()260 void shutdown() { 261 if (shutdown) { 262 return; 263 } 264 shutdown = true; 265 266 for (ManagedChannel ch : channels) { 267 try { 268 ch.shutdownNow(); 269 ch.awaitTermination(1, SECONDS); 270 } catch (Throwable t) { 271 log.log(Level.WARNING, "Error shutting down channel!", t); 272 } 273 } 274 275 try { 276 metricsServer.shutdownNow(); 277 } catch (Throwable t) { 278 log.log(Level.WARNING, "Error shutting down metrics service!", t); 279 } 280 281 try { 282 if (threadpool != null) { 283 threadpool.shutdownNow(); 284 } 285 } catch (Throwable t) { 286 log.log(Level.WARNING, "Error shutting down threadpool.", t); 287 } 288 } 289 290 @VisibleForTesting getMetricServerPort()291 int getMetricServerPort() { 292 return metricsServer.getPort(); 293 } 294 parseServerAddresses(String addressesStr)295 private List<InetSocketAddress> parseServerAddresses(String addressesStr) { 296 List<InetSocketAddress> addresses = new ArrayList<>(); 297 298 for (List<String> namePort : parseCommaSeparatedTuples(addressesStr)) { 299 InetAddress address; 300 String name = namePort.get(0); 301 int port = Integer.valueOf(namePort.get(1)); 302 try { 303 address = InetAddress.getByName(name); 304 if (serverHostOverride != null) { 305 // Force the hostname to match the cert the server uses. 306 address = InetAddress.getByAddress(serverHostOverride, address.getAddress()); 307 } 308 } catch (UnknownHostException ex) { 309 throw new RuntimeException(ex); 310 } 311 addresses.add(new InetSocketAddress(address, port)); 312 } 313 314 return addresses; 315 } 316 parseTestCases(String testCasesStr)317 private static List<TestCaseWeightPair> parseTestCases(String testCasesStr) { 318 List<TestCaseWeightPair> testCaseWeightPairs = new ArrayList<>(); 319 320 for (List<String> nameWeight : parseCommaSeparatedTuples(testCasesStr)) { 321 TestCases testCase = TestCases.fromString(nameWeight.get(0)); 322 int weight = Integer.valueOf(nameWeight.get(1)); 323 testCaseWeightPairs.add(new TestCaseWeightPair(testCase, weight)); 324 } 325 326 return testCaseWeightPairs; 327 } 328 parseCommaSeparatedTuples(String str)329 private static List<List<String>> parseCommaSeparatedTuples(String str) { 330 List<List<String>> tuples = new ArrayList<>(); 331 for (String tupleStr : Splitter.on(',').split(str)) { 332 int splitIdx = tupleStr.lastIndexOf(':'); 333 if (splitIdx == -1) { 334 throw new IllegalArgumentException("Illegal tuple format: '" + tupleStr + "'"); 335 } 336 String part0 = tupleStr.substring(0, splitIdx); 337 String part1 = tupleStr.substring(splitIdx + 1); 338 tuples.add(asList(part0, part1)); 339 } 340 return tuples; 341 } 342 createChannel(InetSocketAddress address)343 private ManagedChannel createChannel(InetSocketAddress address) { 344 SslContext sslContext = null; 345 if (useTestCa) { 346 try { 347 sslContext = GrpcSslContexts.forClient().trustManager( 348 TlsTesting.loadCert("ca.pem")).build(); 349 } catch (Exception ex) { 350 throw new RuntimeException(ex); 351 } 352 } 353 return NettyChannelBuilder.forAddress(address) 354 .negotiationType(useTls ? NegotiationType.TLS : NegotiationType.PLAINTEXT) 355 .sslContext(sslContext) 356 .build(); 357 } 358 serverAddressesToString(List<InetSocketAddress> addresses)359 private static String serverAddressesToString(List<InetSocketAddress> addresses) { 360 List<String> tmp = new ArrayList<>(); 361 for (InetSocketAddress address : addresses) { 362 URI uri; 363 try { 364 uri = new URI(null, null, address.getHostName(), address.getPort(), null, null, null); 365 } catch (URISyntaxException e) { 366 throw new RuntimeException(e); 367 } 368 tmp.add(uri.getAuthority()); 369 } 370 return Joiner.on(',').join(tmp); 371 } 372 validTestCasesHelpText()373 private static String validTestCasesHelpText() { 374 StringBuilder builder = new StringBuilder(); 375 for (TestCases testCase : TestCases.values()) { 376 String strTestcase = testCase.name().toLowerCase(); 377 builder.append("\n ") 378 .append(strTestcase) 379 .append(": ") 380 .append(testCase.description()); 381 } 382 return builder.toString(); 383 } 384 385 /** 386 * A stress test worker. Every stub has its own stress test worker. 387 */ 388 private class Worker implements Runnable { 389 390 // Interval at which the QPS stats of metrics service are updated. 391 private static final long METRICS_COLLECTION_INTERVAL_SECS = 5; 392 393 private final ManagedChannel channel; 394 private final List<TestCaseWeightPair> testCaseWeightPairs; 395 private final Integer durationSec; 396 private final String gaugeName; 397 Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, int durationSec, String gaugeName)398 Worker(ManagedChannel channel, List<TestCaseWeightPair> testCaseWeightPairs, 399 int durationSec, String gaugeName) { 400 Preconditions.checkArgument(durationSec >= -1, "durationSec must be gte -1."); 401 this.channel = Preconditions.checkNotNull(channel, "channel"); 402 this.testCaseWeightPairs = 403 Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); 404 this.durationSec = durationSec == -1 ? null : durationSec; 405 this.gaugeName = Preconditions.checkNotNull(gaugeName, "gaugeName"); 406 } 407 408 @Override run()409 public void run() { 410 // Simplify debugging if the worker crashes / never terminates. 411 Thread.currentThread().setName(gaugeName); 412 413 Tester tester = new Tester(); 414 tester.setUp(); 415 WeightedTestCaseSelector testCaseSelector = new WeightedTestCaseSelector(testCaseWeightPairs); 416 Long endTime = durationSec == null ? null : System.nanoTime() + SECONDS.toNanos(durationSecs); 417 long lastMetricsCollectionTime = initLastMetricsCollectionTime(); 418 // Number of interop testcases run since the last time metrics have been updated. 419 long testCasesSinceLastMetricsCollection = 0; 420 421 while (!Thread.currentThread().isInterrupted() && !shutdown 422 && (endTime == null || endTime - System.nanoTime() > 0)) { 423 try { 424 runTestCase(tester, testCaseSelector.nextTestCase()); 425 } catch (Exception e) { 426 throw new RuntimeException(e); 427 } 428 429 testCasesSinceLastMetricsCollection++; 430 431 double durationSecs = computeDurationSecs(lastMetricsCollectionTime); 432 if (durationSecs >= METRICS_COLLECTION_INTERVAL_SECS) { 433 long qps = (long) Math.ceil(testCasesSinceLastMetricsCollection / durationSecs); 434 435 Metrics.GaugeResponse gauge = Metrics.GaugeResponse 436 .newBuilder() 437 .setName(gaugeName) 438 .setLongValue(qps) 439 .build(); 440 441 gauges.put(gaugeName, gauge); 442 443 lastMetricsCollectionTime = System.nanoTime(); 444 testCasesSinceLastMetricsCollection = 0; 445 } 446 } 447 } 448 initLastMetricsCollectionTime()449 private long initLastMetricsCollectionTime() { 450 return System.nanoTime() - SECONDS.toNanos(METRICS_COLLECTION_INTERVAL_SECS); 451 } 452 computeDurationSecs(long lastMetricsCollectionTime)453 private double computeDurationSecs(long lastMetricsCollectionTime) { 454 return (System.nanoTime() - lastMetricsCollectionTime) / 1000000000.0; 455 } 456 runTestCase(Tester tester, TestCases testCase)457 private void runTestCase(Tester tester, TestCases testCase) throws Exception { 458 // TODO(buchgr): Implement tests requiring auth, once C++ supports it. 459 switch (testCase) { 460 case EMPTY_UNARY: 461 tester.emptyUnary(); 462 break; 463 464 case LARGE_UNARY: 465 tester.largeUnary(); 466 break; 467 468 case CLIENT_STREAMING: 469 tester.clientStreaming(); 470 break; 471 472 case SERVER_STREAMING: 473 tester.serverStreaming(); 474 break; 475 476 case PING_PONG: 477 tester.pingPong(); 478 break; 479 480 case EMPTY_STREAM: 481 tester.emptyStream(); 482 break; 483 484 case UNIMPLEMENTED_METHOD: { 485 tester.unimplementedMethod(); 486 break; 487 } 488 489 case UNIMPLEMENTED_SERVICE: { 490 tester.unimplementedService(); 491 break; 492 } 493 494 case CANCEL_AFTER_BEGIN: { 495 tester.cancelAfterBegin(); 496 break; 497 } 498 499 case CANCEL_AFTER_FIRST_RESPONSE: { 500 tester.cancelAfterFirstResponse(); 501 break; 502 } 503 504 case TIMEOUT_ON_SLEEPING_SERVER: { 505 tester.timeoutOnSleepingServer(); 506 break; 507 } 508 509 default: 510 throw new IllegalArgumentException("Unknown test case: " + testCase); 511 } 512 } 513 514 class Tester extends AbstractInteropTest { 515 @Override createChannel()516 protected ManagedChannel createChannel() { 517 return Worker.this.channel; 518 } 519 520 @Override createChannelBuilder()521 protected ManagedChannelBuilder<?> createChannelBuilder() { 522 throw new UnsupportedOperationException(); 523 } 524 525 @Override operationTimeoutMillis()526 protected int operationTimeoutMillis() { 527 // Don't enforce a timeout when using the interop tests for the stress test client. 528 // Fixes https://github.com/grpc/grpc-java/issues/1812 529 return Integer.MAX_VALUE; 530 } 531 532 @Override metricsExpected()533 protected boolean metricsExpected() { 534 // TODO(zhangkun83): we may want to enable the real google Instrumentation implementation in 535 // stress tests. 536 return false; 537 } 538 } 539 540 class WeightedTestCaseSelector { 541 /** 542 * Randomly shuffled and cyclic sequence that contains each testcase proportionally 543 * to its weight. 544 */ 545 final Iterator<TestCases> testCases; 546 WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs)547 WeightedTestCaseSelector(List<TestCaseWeightPair> testCaseWeightPairs) { 548 Preconditions.checkNotNull(testCaseWeightPairs, "testCaseWeightPairs"); 549 Preconditions.checkArgument(testCaseWeightPairs.size() > 0); 550 551 List<TestCases> testCases = new ArrayList<>(); 552 for (TestCaseWeightPair testCaseWeightPair : testCaseWeightPairs) { 553 for (int i = 0; i < testCaseWeightPair.weight; i++) { 554 testCases.add(testCaseWeightPair.testCase); 555 } 556 } 557 558 shuffle(testCases); 559 560 this.testCases = Iterators.cycle(testCases); 561 } 562 nextTestCase()563 TestCases nextTestCase() { 564 return testCases.next(); 565 } 566 } 567 } 568 569 /** 570 * Service that exports the QPS metrics of the stress test. 571 */ 572 private class MetricsServiceImpl extends MetricsServiceGrpc.MetricsServiceImplBase { 573 574 @Override getAllGauges(Metrics.EmptyMessage request, StreamObserver<Metrics.GaugeResponse> responseObserver)575 public void getAllGauges(Metrics.EmptyMessage request, 576 StreamObserver<Metrics.GaugeResponse> responseObserver) { 577 for (Metrics.GaugeResponse gauge : gauges.values()) { 578 responseObserver.onNext(gauge); 579 } 580 responseObserver.onCompleted(); 581 } 582 583 @Override getGauge(Metrics.GaugeRequest request, StreamObserver<Metrics.GaugeResponse> responseObserver)584 public void getGauge(Metrics.GaugeRequest request, 585 StreamObserver<Metrics.GaugeResponse> responseObserver) { 586 String gaugeName = request.getName(); 587 Metrics.GaugeResponse gauge = gauges.get(gaugeName); 588 if (gauge != null) { 589 responseObserver.onNext(gauge); 590 responseObserver.onCompleted(); 591 } else { 592 responseObserver.onError(new StatusException(Status.NOT_FOUND)); 593 } 594 } 595 } 596 597 @VisibleForTesting 598 static class TestCaseWeightPair { 599 final TestCases testCase; 600 final int weight; 601 TestCaseWeightPair(TestCases testCase, int weight)602 TestCaseWeightPair(TestCases testCase, int weight) { 603 Preconditions.checkArgument(weight >= 0, "weight must be positive."); 604 this.testCase = Preconditions.checkNotNull(testCase, "testCase"); 605 this.weight = weight; 606 } 607 608 @Override equals(Object other)609 public boolean equals(Object other) { 610 if (!(other instanceof TestCaseWeightPair)) { 611 return false; 612 } 613 TestCaseWeightPair that = (TestCaseWeightPair) other; 614 return testCase.equals(that.testCase) && weight == that.weight; 615 } 616 617 @Override hashCode()618 public int hashCode() { 619 return Objects.hashCode(testCase, weight); 620 } 621 } 622 623 @VisibleForTesting addresses()624 List<InetSocketAddress> addresses() { 625 return Collections.unmodifiableList(addresses); 626 } 627 628 @VisibleForTesting serverHostOverride()629 String serverHostOverride() { 630 return serverHostOverride; 631 } 632 633 @VisibleForTesting useTls()634 boolean useTls() { 635 return useTls; 636 } 637 638 @VisibleForTesting useTestCa()639 boolean useTestCa() { 640 return useTestCa; 641 } 642 643 @VisibleForTesting testCaseWeightPairs()644 List<TestCaseWeightPair> testCaseWeightPairs() { 645 return testCaseWeightPairs; 646 } 647 648 @VisibleForTesting durationSecs()649 int durationSecs() { 650 return durationSecs; 651 } 652 653 @VisibleForTesting channelsPerServer()654 int channelsPerServer() { 655 return channelsPerServer; 656 } 657 658 @VisibleForTesting stubsPerChannel()659 int stubsPerChannel() { 660 return stubsPerChannel; 661 } 662 663 @VisibleForTesting metricsPort()664 int metricsPort() { 665 return metricsPort; 666 } 667 } 668