1 /* 2 * Copyright (C) 2008 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.inject.servlet; 18 19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST; 20 import static org.easymock.EasyMock.anyObject; 21 import static org.easymock.EasyMock.createMock; 22 import static org.easymock.EasyMock.eq; 23 import static org.easymock.EasyMock.expect; 24 import static org.easymock.EasyMock.expectLastCall; 25 import static org.easymock.EasyMock.replay; 26 import static org.easymock.EasyMock.verify; 27 28 import com.google.common.collect.ImmutableList; 29 import com.google.common.collect.Sets; 30 import com.google.inject.Binding; 31 import com.google.inject.Injector; 32 import com.google.inject.Key; 33 import com.google.inject.Provider; 34 import com.google.inject.TypeLiteral; 35 import com.google.inject.spi.BindingScopingVisitor; 36 import com.google.inject.util.Providers; 37 import java.io.IOException; 38 import java.util.ArrayList; 39 import java.util.Date; 40 import java.util.HashMap; 41 import java.util.List; 42 import java.util.UUID; 43 import javax.servlet.RequestDispatcher; 44 import javax.servlet.ServletException; 45 import javax.servlet.http.HttpServlet; 46 import javax.servlet.http.HttpServletRequest; 47 import javax.servlet.http.HttpServletResponse; 48 import junit.framework.TestCase; 49 50 /** 51 * Tests forwarding and inclusion (RequestDispatcher actions from the servlet spec). 52 * 53 * @author Dhanji R. Prasanna (dhanji@gmail com) 54 */ 55 public class ServletPipelineRequestDispatcherTest extends TestCase { 56 private static final Key<HttpServlet> HTTP_SERLVET_KEY = Key.get(HttpServlet.class); 57 private static final String A_KEY = "thinglyDEgintly" + new Date() + UUID.randomUUID(); 58 private static final String A_VALUE = 59 ServletPipelineRequestDispatcherTest.class.toString() + new Date() + UUID.randomUUID(); 60 testIncludeManagedServlet()61 public final void testIncludeManagedServlet() throws IOException, ServletException { 62 String pattern = "blah.html"; 63 final ServletDefinition servletDefinition = 64 new ServletDefinition( 65 Key.get(HttpServlet.class), 66 UriPatternType.get(UriPatternType.SERVLET, pattern), 67 new HashMap<String, String>(), 68 null); 69 70 final Injector injector = createMock(Injector.class); 71 final Binding binding = createMock(Binding.class); 72 final HttpServletRequest requestMock = createMock(HttpServletRequest.class); 73 74 expect(requestMock.getAttribute(A_KEY)).andReturn(A_VALUE); 75 76 requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true); 77 requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST); 78 79 final boolean[] run = new boolean[1]; 80 final HttpServlet mockServlet = 81 new HttpServlet() { 82 @Override 83 protected void service( 84 HttpServletRequest request, HttpServletResponse httpServletResponse) 85 throws ServletException, IOException { 86 run[0] = true; 87 88 final Object o = request.getAttribute(A_KEY); 89 assertEquals("Wrong attrib returned - " + o, A_VALUE, o); 90 } 91 }; 92 93 expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true); 94 expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding); 95 expect(injector.getInstance(HTTP_SERLVET_KEY)).andReturn(mockServlet); 96 97 final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class)); 98 99 Binding<ServletDefinition> mockBinding = createMock(Binding.class); 100 expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral()))) 101 .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding)); 102 Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition); 103 expect(mockBinding.getProvider()).andReturn(bindingProvider); 104 105 replay(injector, binding, requestMock, mockBinding); 106 107 // Have to init the Servlet before we can dispatch to it. 108 servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet()); 109 110 final RequestDispatcher dispatcher = 111 new ManagedServletPipeline(injector).getRequestDispatcher(pattern); 112 113 assertNotNull(dispatcher); 114 dispatcher.include(requestMock, createMock(HttpServletResponse.class)); 115 116 assertTrue("Include did not dispatch to our servlet!", run[0]); 117 118 verify(injector, requestMock, mockBinding); 119 } 120 testForwardToManagedServlet()121 public final void testForwardToManagedServlet() throws IOException, ServletException { 122 String pattern = "blah.html"; 123 final ServletDefinition servletDefinition = 124 new ServletDefinition( 125 Key.get(HttpServlet.class), 126 UriPatternType.get(UriPatternType.SERVLET, pattern), 127 new HashMap<String, String>(), 128 null); 129 130 final Injector injector = createMock(Injector.class); 131 final Binding binding = createMock(Binding.class); 132 final HttpServletRequest requestMock = createMock(HttpServletRequest.class); 133 final HttpServletResponse mockResponse = createMock(HttpServletResponse.class); 134 135 expect(requestMock.getAttribute(A_KEY)).andReturn(A_VALUE); 136 137 requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true); 138 requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST); 139 140 expect(mockResponse.isCommitted()).andReturn(false); 141 142 mockResponse.resetBuffer(); 143 expectLastCall().once(); 144 145 final List<String> paths = new ArrayList<>(); 146 final HttpServlet mockServlet = 147 new HttpServlet() { 148 @Override 149 protected void service( 150 HttpServletRequest request, HttpServletResponse httpServletResponse) 151 throws ServletException, IOException { 152 paths.add(request.getRequestURI()); 153 154 final Object o = request.getAttribute(A_KEY); 155 assertEquals("Wrong attrib returned - " + o, A_VALUE, o); 156 } 157 }; 158 159 expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true); 160 expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding); 161 162 expect(injector.getInstance(HTTP_SERLVET_KEY)).andReturn(mockServlet); 163 164 final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class)); 165 166 Binding<ServletDefinition> mockBinding = createMock(Binding.class); 167 expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral()))) 168 .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding)); 169 Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition); 170 expect(mockBinding.getProvider()).andReturn(bindingProvider); 171 172 replay(injector, binding, requestMock, mockResponse, mockBinding); 173 174 // Have to init the Servlet before we can dispatch to it. 175 servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet()); 176 177 final RequestDispatcher dispatcher = 178 new ManagedServletPipeline(injector).getRequestDispatcher(pattern); 179 180 assertNotNull(dispatcher); 181 dispatcher.forward(requestMock, mockResponse); 182 183 assertTrue("Include did not dispatch to our servlet!", paths.contains(pattern)); 184 185 verify(injector, requestMock, mockResponse, mockBinding); 186 } 187 testForwardToManagedServletFailureOnCommittedBuffer()188 public final void testForwardToManagedServletFailureOnCommittedBuffer() 189 throws IOException, ServletException { 190 IllegalStateException expected = null; 191 try { 192 forwardToManagedServletFailureOnCommittedBuffer(); 193 } catch (IllegalStateException ise) { 194 expected = ise; 195 } finally { 196 assertNotNull("Expected IllegalStateException was not thrown", expected); 197 } 198 } 199 forwardToManagedServletFailureOnCommittedBuffer()200 public final void forwardToManagedServletFailureOnCommittedBuffer() 201 throws IOException, ServletException { 202 String pattern = "blah.html"; 203 final ServletDefinition servletDefinition = 204 new ServletDefinition( 205 Key.get(HttpServlet.class), 206 UriPatternType.get(UriPatternType.SERVLET, pattern), 207 new HashMap<String, String>(), 208 null); 209 210 final Injector injector = createMock(Injector.class); 211 final Binding binding = createMock(Binding.class); 212 final HttpServletRequest mockRequest = createMock(HttpServletRequest.class); 213 final HttpServletResponse mockResponse = createMock(HttpServletResponse.class); 214 215 expect(mockResponse.isCommitted()).andReturn(true); 216 217 final HttpServlet mockServlet = 218 new HttpServlet() { 219 @Override 220 protected void service( 221 HttpServletRequest request, HttpServletResponse httpServletResponse) 222 throws ServletException, IOException { 223 224 final Object o = request.getAttribute(A_KEY); 225 assertEquals("Wrong attrib returned - " + o, A_VALUE, o); 226 } 227 }; 228 229 expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true); 230 expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding); 231 232 expect(injector.getInstance(Key.get(HttpServlet.class))).andReturn(mockServlet); 233 234 final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class)); 235 236 Binding<ServletDefinition> mockBinding = createMock(Binding.class); 237 expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral()))) 238 .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding)); 239 Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition); 240 expect(mockBinding.getProvider()).andReturn(bindingProvider); 241 242 replay(injector, binding, mockRequest, mockResponse, mockBinding); 243 244 // Have to init the Servlet before we can dispatch to it. 245 servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet()); 246 247 final RequestDispatcher dispatcher = 248 new ManagedServletPipeline(injector).getRequestDispatcher(pattern); 249 250 assertNotNull(dispatcher); 251 252 try { 253 dispatcher.forward(mockRequest, mockResponse); 254 } finally { 255 verify(injector, mockRequest, mockResponse, mockBinding); 256 } 257 } 258 testWrappedRequestUriAndUrlConsistency()259 public final void testWrappedRequestUriAndUrlConsistency() { 260 final HttpServletRequest mockRequest = createMock(HttpServletRequest.class); 261 expect(mockRequest.getScheme()).andReturn("http"); 262 expect(mockRequest.getServerName()).andReturn("the.server"); 263 expect(mockRequest.getServerPort()).andReturn(12345); 264 replay(mockRequest); 265 HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri"); 266 assertEquals("/new-uri", wrappedRequest.getRequestURI()); 267 assertEquals("http://the.server:12345/new-uri", wrappedRequest.getRequestURL().toString()); 268 } 269 testWrappedRequestUrlNegativePort()270 public final void testWrappedRequestUrlNegativePort() { 271 final HttpServletRequest mockRequest = createMock(HttpServletRequest.class); 272 expect(mockRequest.getScheme()).andReturn("http"); 273 expect(mockRequest.getServerName()).andReturn("the.server"); 274 expect(mockRequest.getServerPort()).andReturn(-1); 275 replay(mockRequest); 276 HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri"); 277 assertEquals("/new-uri", wrappedRequest.getRequestURI()); 278 assertEquals("http://the.server/new-uri", wrappedRequest.getRequestURL().toString()); 279 } 280 testWrappedRequestUrlDefaultPort()281 public final void testWrappedRequestUrlDefaultPort() { 282 final HttpServletRequest mockRequest = createMock(HttpServletRequest.class); 283 expect(mockRequest.getScheme()).andReturn("http"); 284 expect(mockRequest.getServerName()).andReturn("the.server"); 285 expect(mockRequest.getServerPort()).andReturn(80); 286 replay(mockRequest); 287 HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri"); 288 assertEquals("/new-uri", wrappedRequest.getRequestURI()); 289 assertEquals("http://the.server/new-uri", wrappedRequest.getRequestURL().toString()); 290 } 291 testWrappedRequestUrlDefaultHttpsPort()292 public final void testWrappedRequestUrlDefaultHttpsPort() { 293 final HttpServletRequest mockRequest = createMock(HttpServletRequest.class); 294 expect(mockRequest.getScheme()).andReturn("https"); 295 expect(mockRequest.getServerName()).andReturn("the.server"); 296 expect(mockRequest.getServerPort()).andReturn(443); 297 replay(mockRequest); 298 HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri"); 299 assertEquals("/new-uri", wrappedRequest.getRequestURI()); 300 assertEquals("https://the.server/new-uri", wrappedRequest.getRequestURL().toString()); 301 } 302 } 303