#region Copyright notice and license

// Copyright 2015-2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
using Grpc.Testing;
using NUnit.Framework;

namespace Grpc.IntegrationTesting
{
    /// <summary>
    /// Test SSL credentials where server authenticates client
    /// and client authenticates the server.
    /// </summary>
    public class SslCredentialsTest
    {
        const string Host = "localhost";

        // Sentinel returned by the server's StreamingInputCall handler only after
        // all of its AuthContext assertions have passed; the client checks for it.
        const int AuthContextVerifiedPayloadSize = 12345;

        Server server;
        Channel channel;
        TestService.TestServiceClient client;

        /// <summary>
        /// Starts an SSL-secured test server on an unused localhost port and creates
        /// a client channel to it that presents a client certificate.
        /// </summary>
        [OneTimeSetUp]
        public void Init()
        {
            // The same root certificate is used on both sides: the server verifies
            // client certificates against it and the client verifies the server's
            // certificate against it.
            var rootCert = File.ReadAllText(TestCredentials.ClientCertAuthorityPath);
            var keyCertPair = new KeyCertificatePair(
                File.ReadAllText(TestCredentials.ServerCertChainPath),
                File.ReadAllText(TestCredentials.ServerPrivateKeyPath));

            // The trailing 'true' is the forceClientAuth flag: the server requires
            // the client to present a certificate.
            var serverCredentials = new SslServerCredentials(new[] { keyCertPair }, rootCert, true);
            // The client presents the same key/cert pair as its own certificate.
            var clientCredentials = new SslCredentials(rootCert, keyCertPair);

            // Disable SO_REUSEPORT to prevent https://github.com/grpc/grpc/issues/10755
            server = new Server(new[] { new ChannelOption(ChannelOptions.SoReuseport, 0) })
            {
                Services = { TestService.BindService(new SslCredentialsTestServiceImpl()) },
                Ports = { { Host, ServerPort.PickUnused, serverCredentials } }
            };
            server.Start();

            // Override the TLS target name checked against the server certificate;
            // presumably the test certificate is not issued for "localhost".
            var options = new List<ChannelOption>
            {
                new ChannelOption(ChannelOptions.SslTargetNameOverride, TestCredentials.DefaultHostOverride)
            };

            channel = new Channel(Host, server.Ports.Single().BoundPort, clientCredentials, options);
            client = new TestService.TestServiceClient(channel);
        }

        /// <summary>
        /// Shuts down the client channel and then the server.
        /// </summary>
        /// <remarks>
        /// NUnit runs one-time teardown even when one-time setup fails, so either field
        /// may still be null here; the null checks keep such a failure from being masked
        /// by a NullReferenceException. The try/finally ensures the server is shut down
        /// (releasing its port) even if the channel shutdown throws.
        /// </remarks>
        [OneTimeTearDown]
        public async Task Cleanup()
        {
            try
            {
                if (channel != null)
                {
                    await channel.ShutdownAsync();
                }
            }
            finally
            {
                if (server != null)
                {
                    await server.ShutdownAsync();
                }
            }
        }

        /// <summary>
        /// A unary call succeeds over the mutually-authenticated channel and returns
        /// a payload of the requested size.
        /// </summary>
        [Test]
        public void AuthenticatedClientAndServer()
        {
            var response = client.UnaryCall(new SimpleRequest { ResponseSize = 10 });
            Assert.AreEqual(10, response.Payload.Body.Length);
        }

        /// <summary>
        /// The server-side handler asserts on its AuthContext. It returns the sentinel
        /// payload size only if every assertion passes. Otherwise the handler throws and
        /// the client never receives the sentinel.
        /// </summary>
        [Test]
        public async Task AuthContextIsPopulated()
        {
            var call = client.StreamingInputCall();
            await call.RequestStream.CompleteAsync();
            var response = await call.ResponseAsync;
            Assert.AreEqual(AuthContextVerifiedPayloadSize, response.AggregatedPayloadSize);
        }

        private class SslCredentialsTestServiceImpl : TestService.TestServiceBase
        {
            /// <summary>Returns a payload of <c>request.ResponseSize</c> zero bytes.</summary>
            public override Task<SimpleResponse> UnaryCall(SimpleRequest request, ServerCallContext context)
            {
                return Task.FromResult(new SimpleResponse { Payload = CreateZerosPayload(request.ResponseSize) });
            }

            /// <summary>
            /// Drains the request stream, then asserts that the call's AuthContext describes
            /// an SSL-authenticated peer identified by its X.509 subject alternative name.
            /// Returns <see cref="AuthContextVerifiedPayloadSize"/> only when every assertion passes.
            /// </summary>
            public override async Task<StreamingInputCallResponse> StreamingInputCall(IAsyncStreamReader<StreamingInputCallRequest> requestStream, ServerCallContext context)
            {
                var authContext = context.AuthContext;
                await requestStream.ForEachAsync(request => TaskUtils.CompletedTask);

                Assert.IsTrue(authContext.IsPeerAuthenticated);
                Assert.AreEqual("x509_subject_alternative_name", authContext.PeerIdentityPropertyName);
                Assert.IsTrue(authContext.PeerIdentity.Any());
                Assert.AreEqual("ssl", authContext.FindPropertiesByName("transport_security_type").First().Value);

                return new StreamingInputCallResponse { AggregatedPayloadSize = AuthContextVerifiedPayloadSize };
            }

            private static Payload CreateZerosPayload(int size)
            {
                return new Payload { Body = ByteString.CopyFrom(new byte[size]) };
            }
        }
    }
}