#region Copyright notice and license

// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Core.Internal;
using Grpc.Core.Tests;
using Grpc.Core.Utils;
using NUnit.Framework;

namespace Grpc.Core.Interceptors.Tests
{
    /// <summary>
    /// Tests for server-side interceptors registered through
    /// <c>ServerServiceDefinition.Intercept(...)</c>.
    /// </summary>
    public class ServerInterceptorTest
    {
        const string Host = "127.0.0.1";

        /// <summary>
        /// An interceptor that appends a request header must make that header visible to the handler.
        /// </summary>
        [Test]
        public void AddRequestHeaderInServerInterceptor()
        {
            var helper = new MockServiceHelper(Host);
            const string MetadataKey = "x-interceptor";
            const string MetadataValue = "hello world";
            var interceptor = new ServerCallContextInterceptor(ctx => ctx.RequestHeaders.Add(new Metadata.Entry(MetadataKey, MetadataValue)));
            helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
            {
                // NOTE: this assertion runs inside the server handler; a failure presumably surfaces
                // to the client as a failed RPC rather than directly as a test failure.
                var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == MetadataKey)).Value;
                Assert.AreEqual(MetadataValue, interceptorHeader);
                return Task.FromResult("PASS");
            });
            helper.ServiceDefinition = helper.ServiceDefinition.Intercept(interceptor);

            Assert.AreEqual("PASS", StartServerAndInvokeUnary(helper));
        }

        /// <summary>
        /// Verifies the invocation order of interceptors added across several <c>Intercept</c> calls.
        /// </summary>
        [Test]
        public void VerifyInterceptorOrdering()
        {
            var helper = new MockServiceHelper(Host);
            helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) => Task.FromResult("PASS"));
            var stringBuilder = new StringBuilder();
            // The expected trace "CB1B2B3A" pins the contract: the interceptor added by the last
            // Intercept() call runs first, while interceptors passed together in one call run in
            // array order.
            helper.ServiceDefinition = helper.ServiceDefinition
                .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("A")))
                .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("B1")),
                    new ServerCallContextInterceptor(ctx => stringBuilder.Append("B2")),
                    new ServerCallContextInterceptor(ctx => stringBuilder.Append("B3")))
                .Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("C")));

            Assert.AreEqual("PASS", StartServerAndInvokeUnary(helper));
            Assert.AreEqual("CB1B2B3A", stringBuilder.ToString());
        }

        /// <summary>
        /// State stored in <c>ServerCallContext.UserState</c> by one interceptor must be visible to
        /// interceptors that run later and to the call handler.
        /// </summary>
        [Test]
        public void UserStateVisibleToAllInterceptors()
        {
            object key1 = new object();
            object value1 = new object();
            const string key2 = "Interceptor #2";
            const string value2 = "Important state";

            var interceptor1 = new ServerCallContextInterceptor(ctx =>
            {
                // state starts off empty
                Assert.AreEqual(0, ctx.UserState.Count);

                ctx.UserState.Add(key1, value1);
            });

            var interceptor2 = new ServerCallContextInterceptor(ctx =>
            {
                // second interceptor can see state set by the first
                bool found = ctx.UserState.TryGetValue(key1, out object storedValue1);
                Assert.IsTrue(found);
                Assert.AreEqual(value1, storedValue1);

                ctx.UserState.Add(key2, value2);
            });

            var helper = new MockServiceHelper(Host);
            helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
            {
                // call handler can see all the state
                bool found = context.UserState.TryGetValue(key1, out object storedValue1);
                Assert.IsTrue(found);
                Assert.AreEqual(value1, storedValue1);

                found = context.UserState.TryGetValue(key2, out object storedValue2);
                Assert.IsTrue(found);
                Assert.AreEqual(value2, storedValue2);

                return Task.FromResult("PASS");
            });
            // interceptor1 is added last, so it runs first (same ordering rule as VerifyInterceptorOrdering).
            helper.ServiceDefinition = helper.ServiceDefinition
                .Intercept(interceptor2)
                .Intercept(interceptor1);

            Assert.AreEqual("PASS", StartServerAndInvokeUnary(helper));
        }

        /// <summary>
        /// Registering a null interceptor, a null interceptor array, or an array containing null must throw.
        /// </summary>
        [Test]
        public void CheckNullInterceptorRegistrationFails()
        {
            var helper = new MockServiceHelper(Host);
            var sd = helper.ServiceDefinition;
            Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor)));
            Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[] { default(Interceptor) }));
            Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[] { new ServerCallContextInterceptor(ctx => { }), null }));
            Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor[])));
        }

        /// <summary>
        /// Starts the server built from <paramref name="helper"/>, performs one blocking unary call with
        /// an empty request, and then shuts down the channel and the server. Shutdown runs in
        /// <c>finally</c> blocks, so it also happens when the call throws; a test therefore does not
        /// leave a running server or channel behind.
        /// </summary>
        /// <param name="helper">Fully configured helper (handler and service definition already set).</param>
        /// <returns>The response of the unary call.</returns>
        private static string StartServerAndInvokeUnary(MockServiceHelper helper)
        {
            var server = helper.GetServer();
            server.Start();
            try
            {
                // The channel must be created before CreateUnaryCall(), matching the original call order.
                var channel = helper.GetChannel();
                try
                {
                    return Calls.BlockingUnaryCall(helper.CreateUnaryCall(), "");
                }
                finally
                {
                    channel.ShutdownAsync().Wait();
                }
            }
            finally
            {
                server.ShutdownAsync().Wait();
            }
        }

        /// <summary>
        /// Test interceptor that invokes a callback with the <see cref="ServerCallContext"/> for every
        /// server call kind, and then delegates unchanged to the continuation.
        /// </summary>
        private sealed class ServerCallContextInterceptor : Interceptor
        {
            readonly Action<ServerCallContext> interceptor;

            /// <param name="interceptor">Callback run before the continuation; must not be null.</param>
            /// <exception cref="ArgumentNullException"><paramref name="interceptor"/> is null.</exception>
            public ServerCallContextInterceptor(Action<ServerCallContext> interceptor)
            {
                GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor));
                this.interceptor = interceptor;
            }

            public override Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
            {
                interceptor(context);
                return continuation(request, context);
            }

            public override Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
            {
                interceptor(context);
                return continuation(requestStream, context);
            }

            public override Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
            {
                interceptor(context);
                return continuation(request, responseStream, context);
            }

            public override Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
            {
                interceptor(context);
                return continuation(requestStream, responseStream, context);
            }
        }
    }
}