#region Copyright notice and license

// Copyright 2015-2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using CommandLine;
using CommandLine.Text;
using Grpc.Core;
using Grpc.Core.Logging;
using Grpc.Core.Utils;
using Grpc.Testing;

namespace Grpc.IntegrationTesting
{
    /// <summary>
    /// Stress test client. Repeatedly runs a weighted random mix of interop test cases
    /// against one or more servers from many concurrent workers, records each test case's
    /// latency in a histogram and exposes the achieved QPS through a metrics service.
    /// </summary>
    public class StressTestClient
    {
        static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<StressTestClient>();
        const double SecondsToNanos = 1e9;

        private class ClientOptions
        {
            [Option("server_addresses", Default = "localhost:8080")]
            public string ServerAddresses { get; set; }

            [Option("test_cases", Default = "large_unary:100")]
            public string TestCases { get; set; }

            [Option("test_duration_secs", Default = -1)]
            public int TestDurationSecs { get; set; }

            [Option("num_channels_per_server", Default = 1)]
            public int NumChannelsPerServer { get; set; }

            [Option("num_stubs_per_channel", Default = 1)]
            public int NumStubsPerChannel { get; set; }

            [Option("metrics_port", Default = 8081)]
            public int MetricsPort { get; set; }
        }

        readonly ClientOptions options;
        readonly List<string> serverAddresses;
        readonly Dictionary<string, int> weightedTestCases;
        readonly WeightedRandomGenerator testCaseGenerator;

        // cancellation will be emitted once test_duration_secs has elapsed.
        readonly CancellationTokenSource finishedTokenSource = new CancellationTokenSource();
        // Per-test-case latencies in nanoseconds; also read by the metrics service to compute QPS.
        readonly Histogram histogram = new Histogram(0.01, 60 * SecondsToNanos);

        private StressTestClient(ClientOptions options, List<string> serverAddresses, Dictionary<string, int> weightedTestCases)
        {
            this.options = options;
            this.serverAddresses = serverAddresses;
            this.weightedTestCases = weightedTestCases;
            this.testCaseGenerator = new WeightedRandomGenerator(this.weightedTestCases);
        }

        /// <summary>
        /// Entry point: parses command line arguments, validates them and runs the stress test
        /// to completion (blocking). Exits the process with code 1 if the arguments cannot be parsed.
        /// </summary>
        /// <param name="args">Command line arguments.</param>
        public static void Run(string[] args)
        {
            GrpcEnvironment.SetLogger(new ConsoleLogger());
            Parser.Default.ParseArguments<ClientOptions>(args)
                .WithNotParsed((x) => Environment.Exit(1))
                .WithParsed(options => {
                    GrpcPreconditions.CheckArgument(options.NumChannelsPerServer > 0);
                    GrpcPreconditions.CheckArgument(options.NumStubsPerChannel > 0);

                    // string.Split never returns an empty array, so drop blank entries
                    // explicitly; otherwise the emptiness check below could never fail.
                    var serverAddresses = options.ServerAddresses.Split(',')
                        .Select(address => address.Trim())
                        .Where(address => address.Length > 0)
                        .ToList();
                    GrpcPreconditions.CheckArgument(serverAddresses.Count > 0, "You need to provide at least one server address");

                    var testCases = ParseWeightedTestCases(options.TestCases);
                    GrpcPreconditions.CheckArgument(testCases.Count > 0, "You need to provide at least one test case");
                    // With no positive weight, WeightedRandomGenerator.GetNext() could never pick anything.
                    GrpcPreconditions.CheckArgument(testCases.Values.Any(weight => weight > 0),
                        "At least one test case needs a positive weight");

                    var interopClient = new StressTestClient(options, serverAddresses, testCases);
                    interopClient.Run().Wait();
                });
        }

        /// <summary>
        /// Starts the metrics server, spawns one long-running worker per stub
        /// (servers x channels per server x stubs per channel) and waits for all of them.
        /// If test_duration_secs is non-negative, the workers are asked to stop once it elapses.
        /// A negative value never schedules cancellation, so the workers loop indefinitely.
        /// Channels and the metrics server are shut down even if a worker fails.
        /// </summary>
        async Task Run()
        {
            var metricsServer = new Server()
            {
                Services = { MetricsService.BindService(new MetricsServiceImpl(histogram)) },
                Ports = { { "[::]", options.MetricsPort, ServerCredentials.Insecure } }
            };
            metricsServer.Start();

            var channels = new List<Channel>();
            try
            {
                if (options.TestDurationSecs >= 0)
                {
                    finishedTokenSource.CancelAfter(TimeSpan.FromSeconds(options.TestDurationSecs));
                }

                var tasks = new List<Task>();
                foreach (var serverAddress in serverAddresses)
                {
                    for (int i = 0; i < options.NumChannelsPerServer; i++)
                    {
                        var channel = new Channel(serverAddress, ChannelCredentials.Insecure);
                        channels.Add(channel);
                        for (int j = 0; j < options.NumStubsPerChannel; j++)
                        {
                            var client = new TestService.TestServiceClient(channel);
                            // Some test cases (empty_unary, large_unary) block synchronously, so each
                            // worker gets its own dedicated long-running thread.
                            var task = Task.Factory.StartNew(() => RunBodyAsync(client).GetAwaiter().GetResult(),
                                TaskCreationOptions.LongRunning);
                            tasks.Add(task);
                        }
                    }
                }
                await Task.WhenAll(tasks);
            }
            finally
            {
                foreach (var channel in channels)
                {
                    await channel.ShutdownAsync();
                }

                await metricsServer.ShutdownAsync();
            }
        }

        /// <summary>
        /// Worker loop: until the finish token is cancelled, picks a weighted random test case,
        /// runs it and records its wall-clock duration (in nanoseconds) in the shared histogram.
        /// An exception thrown by a test case ends this worker.
        /// </summary>
        async Task RunBodyAsync(TestService.TestServiceClient client)
        {
            Logger.Info("Starting stress test client thread.");
            while (!finishedTokenSource.Token.IsCancellationRequested)
            {
                var testCase = testCaseGenerator.GetNext();

                var stopwatch = Stopwatch.StartNew();

                await RunTestCaseAsync(client, testCase);

                stopwatch.Stop();
                histogram.AddObservation(stopwatch.Elapsed.TotalSeconds * SecondsToNanos);
            }
            Logger.Info("Stress test client thread finished.");
        }

        /// <summary>
        /// Runs a single interop test case identified by name.
        /// </summary>
        /// <exception cref="ArgumentException">The test case name is not supported.</exception>
        async Task RunTestCaseAsync(TestService.TestServiceClient client, string testCase)
        {
            switch (testCase)
            {
                case "empty_unary":
                    InteropClient.RunEmptyUnary(client);
                    break;
                case "large_unary":
                    InteropClient.RunLargeUnary(client);
                    break;
                case "client_streaming":
                    await InteropClient.RunClientStreamingAsync(client);
                    break;
                case "server_streaming":
                    await InteropClient.RunServerStreamingAsync(client);
                    break;
                case "ping_pong":
                    await InteropClient.RunPingPongAsync(client);
                    break;
                case "empty_stream":
                    await InteropClient.RunEmptyStreamAsync(client);
                    break;
                case "cancel_after_begin":
                    await InteropClient.RunCancelAfterBeginAsync(client);
                    break;
                case "cancel_after_first_response":
                    await InteropClient.RunCancelAfterFirstResponseAsync(client);
                    break;
                case "timeout_on_sleeping_server":
                    await InteropClient.RunTimeoutOnSleepingServerAsync(client);
                    break;
                case "custom_metadata":
                    await InteropClient.RunCustomMetadataAsync(client);
                    break;
                case "status_code_and_message":
                    await InteropClient.RunStatusCodeAndMessageAsync(client);
                    break;
                default:
                    throw new ArgumentException("Unsupported test case " + testCase);
            }
        }

        /// <summary>
        /// Parses a "name:weight,name:weight,..." specification into a name-to-weight map.
        /// Names are trimmed; weights must be non-negative integers (parsed culture-invariantly)
        /// and each name may appear only once.
        /// </summary>
        /// <exception cref="ArgumentException">The specification is malformed.</exception>
        static Dictionary<string, int> ParseWeightedTestCases(string weightedTestCases)
        {
            var result = new Dictionary<string, int>();
            foreach (var weightedTestCase in weightedTestCases.Split(','))
            {
                var parts = weightedTestCase.Split(new char[] {':'}, 2);
                GrpcPreconditions.CheckArgument(parts.Length == 2, "Malformed test_cases option.");

                var name = parts[0].Trim();
                GrpcPreconditions.CheckArgument(name.Length > 0, "Malformed test_cases option: empty test case name.");

                int weight;
                // Negative weights would make Random.Next(weightSum) throw later, inside a worker.
                GrpcPreconditions.CheckArgument(
                    int.TryParse(parts[1], NumberStyles.Integer, CultureInfo.InvariantCulture, out weight) && weight >= 0,
                    "Malformed test_cases option: weight must be a non-negative integer.");
                GrpcPreconditions.CheckArgument(!result.ContainsKey(name),
                    "Malformed test_cases option: duplicate test case " + name);

                result.Add(name, weight);
            }
            return result;
        }

        /// <summary>
        /// Picks keys at random with probability proportional to their integer weight.
        /// Safe to call concurrently from multiple workers.
        /// </summary>
        class WeightedRandomGenerator
        {
            // System.Random is not thread-safe, and GetNext() is called concurrently by every
            // worker; concurrent use can corrupt its state, so access is serialized by this lock.
            readonly object myLock = new object();
            readonly Random random = new Random();
            readonly List<Tuple<int, string>> cumulativeSums;
            readonly int weightSum;

            public WeightedRandomGenerator(Dictionary<string, int> weightedItems)
            {
                cumulativeSums = new List<Tuple<int, string>>();
                weightSum = 0;
                foreach (var entry in weightedItems)
                {
                    weightSum += entry.Value;
                    cumulativeSums.Add(Tuple.Create(weightSum, entry.Key));
                }
            }

            /// <summary>
            /// Returns the next randomly chosen key.
            /// </summary>
            /// <exception cref="InvalidOperationException">No item has a positive weight.</exception>
            public string GetNext()
            {
                int rand;
                lock (myLock)
                {
                    rand = random.Next(weightSum);
                }

                // The first cumulative sum strictly greater than rand owns that slot.
                foreach (var entry in cumulativeSums)
                {
                    if (rand < entry.Item1)
                    {
                        return entry.Item2;
                    }
                }
                throw new InvalidOperationException("GetNext() failed.");
            }
        }

        /// <summary>
        /// Metrics service exposing a single gauge: the QPS observed since the previous query.
        /// </summary>
        class MetricsServiceImpl : MetricsService.MetricsServiceBase
        {
            const string GaugeName = "csharp_overall_qps";

            readonly Histogram histogram;
            readonly TimeStats timeStats = new TimeStats();

            public MetricsServiceImpl(Histogram histogram)
            {
                this.histogram = histogram;
            }

            public override Task<GaugeResponse> GetGauge(GaugeRequest request, ServerCallContext context)
            {
                if (request.Name == GaugeName)
                {
                    long qps = GetQpsAndReset();

                    return Task.FromResult(new GaugeResponse
                    {
                        Name = GaugeName,
                        LongValue = qps
                    });
                }
                throw new RpcException(new Status(StatusCode.InvalidArgument, "Gauge does not exist"));
            }

            public override async Task GetAllGauges(EmptyMessage request, IServerStreamWriter<GaugeResponse> responseStream, ServerCallContext context)
            {
                long qps = GetQpsAndReset();

                var response = new GaugeResponse
                {
                    Name = GaugeName,
                    LongValue = qps
                };
                await responseStream.WriteAsync(response);
            }

            /// <summary>
            /// Computes observations per wall-clock second since the last reset, then resets
            /// both the histogram and the time stats.
            /// </summary>
            long GetQpsAndReset()
            {
                var snapshot = histogram.GetSnapshot(true);
                var timeSnapshot = timeStats.GetSnapshot(true);

                double elapsedSeconds = timeSnapshot.WallClockTime.TotalSeconds;
                if (elapsedSeconds <= 0)
                {
                    // Avoid dividing by zero: casting the resulting Infinity/NaN to long
                    // produces an unspecified value.
                    return 0;
                }
                return (long) (snapshot.Count / elapsedSeconds);
            }
        }
    }
}