1 /** 2 * Copyright (C) 2011 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.inject.servlet; 18 19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST; 20 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest; 21 import static com.google.inject.servlet.ServletTestUtils.newNoOpFilterChain; 22 import static org.easymock.EasyMock.expect; 23 import static org.easymock.EasyMock.expectLastCall; 24 25 import com.google.inject.Guice; 26 import com.google.inject.Injector; 27 import com.google.inject.Key; 28 import com.google.inject.Singleton; 29 30 import junit.framework.TestCase; 31 32 import org.easymock.EasyMock; 33 import org.easymock.IMocksControl; 34 35 import java.io.IOException; 36 import java.util.ArrayList; 37 import java.util.List; 38 import java.util.concurrent.atomic.AtomicInteger; 39 40 import javax.servlet.Filter; 41 import javax.servlet.FilterChain; 42 import javax.servlet.FilterConfig; 43 import javax.servlet.ServletException; 44 import javax.servlet.ServletRequest; 45 import javax.servlet.ServletResponse; 46 import javax.servlet.http.HttpServlet; 47 import javax.servlet.http.HttpServletRequest; 48 import javax.servlet.http.HttpServletResponse; 49 50 /** 51 * 52 * This tests that filter stage of the pipeline dispatches 53 * correctly to guice-managed filters. 54 * 55 * WARNING(dhanji): Non-parallelizable test =( 56 * 57 * @author dhanji@gmail.com (Dhanji R. Prasanna) 58 */ 59 public class FilterDispatchIntegrationTest extends TestCase { 60 private static int inits, doFilters, destroys; 61 62 private IMocksControl control; 63 64 @Override setUp()65 public final void setUp() { 66 inits = 0; 67 doFilters = 0; 68 destroys = 0; 69 control = EasyMock.createControl(); 70 GuiceFilter.reset(); 71 } 72 73 testDispatchRequestToManagedPipeline()74 public final void testDispatchRequestToManagedPipeline() throws ServletException, IOException { 75 final Injector injector = Guice.createInjector(new ServletModule() { 76 77 @Override 78 protected void configureServlets() { 79 filter("/*").through(TestFilter.class); 80 filter("*.html").through(TestFilter.class); 81 filter("/*").through(Key.get(TestFilter.class)); 82 83 // These filters should never fire 84 filter("/index/*").through(Key.get(TestFilter.class)); 85 filter("*.jsp").through(Key.get(TestFilter.class)); 86 87 // Bind a servlet 88 serve("*.html").with(TestServlet.class); 89 } 90 }); 91 92 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 93 pipeline.initPipeline(null); 94 95 // create ourselves a mock request with test URI 96 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 97 98 expect(requestMock.getRequestURI()) 99 .andReturn("/index.html") 100 .anyTimes(); 101 expect(requestMock.getContextPath()) 102 .andReturn("") 103 .anyTimes(); 104 105 requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true); 106 requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST); 107 108 HttpServletResponse responseMock = control.createMock(HttpServletResponse.class); 109 expect(responseMock.isCommitted()) 110 .andReturn(false) 111 .anyTimes(); 112 responseMock.resetBuffer(); 113 expectLastCall().anyTimes(); 114 115 FilterChain filterChain = control.createMock(FilterChain.class); 116 117 //dispatch request 118 control.replay(); 119 pipeline.dispatch(requestMock, responseMock, filterChain); 120 pipeline.destroyPipeline(); 121 control.verify(); 122 123 TestServlet servlet = injector.getInstance(TestServlet.class); 124 assertEquals(2, servlet.processedUris.size()); 125 assertTrue(servlet.processedUris.contains("/index.html")); 126 assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO)); 127 128 assertTrue("lifecycle states did not" 129 + " fire correct number of times-- inits: " + inits + "; dos: " + doFilters 130 + "; destroys: " + destroys, inits == 1 && doFilters == 3 && destroys == 1); 131 } 132 testDispatchThatNoFiltersFire()133 public final void testDispatchThatNoFiltersFire() throws ServletException, IOException { 134 final Injector injector = Guice.createInjector(new ServletModule() { 135 136 @Override 137 protected void configureServlets() { 138 filter("/public/*").through(TestFilter.class); 139 filter("*.html").through(TestFilter.class); 140 filter("*.xml").through(Key.get(TestFilter.class)); 141 142 // These filters should never fire 143 filter("/index/*").through(Key.get(TestFilter.class)); 144 filter("*.jsp").through(Key.get(TestFilter.class)); 145 } 146 }); 147 148 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 149 pipeline.initPipeline(null); 150 151 //create ourselves a mock request with test URI 152 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 153 154 expect(requestMock.getRequestURI()) 155 .andReturn("/index.xhtml") 156 .anyTimes(); 157 expect(requestMock.getContextPath()) 158 .andReturn("") 159 .anyTimes(); 160 161 //dispatch request 162 FilterChain filterChain = control.createMock(FilterChain.class); 163 filterChain.doFilter(requestMock, null); 164 control.replay(); 165 pipeline.dispatch(requestMock, null, filterChain); 166 pipeline.destroyPipeline(); 167 control.verify(); 168 169 assertTrue("lifecycle states did not " 170 + "fire correct number of times-- inits: " + inits + "; dos: " + doFilters 171 + "; destroys: " + destroys, 172 inits == 1 && doFilters == 0 && destroys == 1); 173 } 174 testDispatchFilterPipelineWithRegexMatching()175 public final void testDispatchFilterPipelineWithRegexMatching() throws ServletException, 176 IOException { 177 178 final Injector injector = Guice.createInjector(new ServletModule() { 179 180 @Override 181 protected void configureServlets() { 182 filterRegex("/[A-Za-z]*").through(TestFilter.class); 183 filterRegex("/index").through(TestFilter.class); 184 //these filters should never fire 185 filterRegex("\\w").through(Key.get(TestFilter.class)); 186 } 187 }); 188 189 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 190 pipeline.initPipeline(null); 191 192 //create ourselves a mock request with test URI 193 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 194 195 expect(requestMock.getRequestURI()) 196 .andReturn("/index") 197 .anyTimes(); 198 expect(requestMock.getContextPath()) 199 .andReturn("") 200 .anyTimes(); 201 202 // dispatch request 203 FilterChain filterChain = control.createMock(FilterChain.class); 204 filterChain.doFilter(requestMock, null); 205 control.replay(); 206 pipeline.dispatch(requestMock, null, filterChain); 207 pipeline.destroyPipeline(); 208 control.verify(); 209 210 assertTrue("lifecycle states did not fire " 211 + "correct number of times-- inits: " + inits + "; dos: " + doFilters 212 + "; destroys: " + destroys, 213 inits == 1 && doFilters == 2 && destroys == 1); 214 } 215 216 @Singleton 217 public static class TestFilter implements Filter { init(FilterConfig filterConfig)218 public void init(FilterConfig filterConfig) { 219 inits++; 220 } 221 doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)222 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, 223 FilterChain filterChain) throws IOException, ServletException { 224 doFilters++; 225 filterChain.doFilter(servletRequest, servletResponse); 226 } 227 destroy()228 public void destroy() { 229 destroys++; 230 } 231 } 232 233 @Singleton 234 public static class TestServlet extends HttpServlet { 235 public static final String FORWARD_FROM = "/index.html"; 236 public static final String FORWARD_TO = "/forwarded.html"; 237 public List<String> processedUris = new ArrayList<String>(); 238 239 @Override service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)240 protected void service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) 241 throws ServletException, IOException { 242 String requestUri = httpServletRequest.getRequestURI(); 243 processedUris.add(requestUri); 244 245 // If the client is requesting /index.html then we forward to /forwarded.html 246 if (FORWARD_FROM.equals(requestUri)) { 247 httpServletRequest.getRequestDispatcher(FORWARD_TO) 248 .forward(httpServletRequest, httpServletResponse); 249 } 250 } 251 252 @Override service(ServletRequest servletRequest, ServletResponse servletResponse)253 public void service(ServletRequest servletRequest, ServletResponse servletResponse) 254 throws ServletException, IOException { 255 service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse); 256 } 257 } 258 testFilterOrder()259 public void testFilterOrder() throws Exception { 260 AtomicInteger counter = new AtomicInteger(); 261 final CountFilter f1 = new CountFilter(counter); 262 final CountFilter f2 = new CountFilter(counter); 263 264 Injector injector = Guice.createInjector(new ServletModule() { 265 @Override 266 protected void configureServlets() { 267 filter("/").through(f1); 268 install(new ServletModule() { 269 @Override 270 protected void configureServlets() { 271 filter("/").through(f2); 272 } 273 }); 274 } 275 }); 276 277 HttpServletRequest request = newFakeHttpServletRequest(); 278 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 279 pipeline.initPipeline(null); 280 pipeline.dispatch(request, null, newNoOpFilterChain()); 281 assertEquals(0, f1.calledAt); 282 assertEquals(1, f2.calledAt); 283 } 284 285 /** A filter that keeps count of when it was called by increment a counter. */ 286 private static class CountFilter implements Filter { 287 private final AtomicInteger counter; 288 private int calledAt = -1; 289 CountFilter(AtomicInteger counter)290 public CountFilter(AtomicInteger counter) { 291 this.counter = counter; 292 } 293 destroy()294 public void destroy() { 295 } 296 doFilter(ServletRequest request, ServletResponse response, FilterChain chain)297 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) 298 throws ServletException, IOException { 299 if (calledAt != -1) { 300 fail("not expecting to be called twice"); 301 } 302 calledAt = counter.getAndIncrement(); 303 chain.doFilter(request, response); 304 } 305 init(FilterConfig filterConfig)306 public void init(FilterConfig filterConfig) {} 307 } 308 testFilterExceptionPrunesStack()309 public final void testFilterExceptionPrunesStack() throws Exception { 310 Injector injector = Guice.createInjector(new ServletModule() { 311 @Override 312 protected void configureServlets() { 313 filter("/").through(TestFilter.class); 314 filter("/nothing").through(TestFilter.class); 315 filter("/").through(ThrowingFilter.class); 316 } 317 }); 318 319 HttpServletRequest request = newFakeHttpServletRequest(); 320 FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 321 pipeline.initPipeline(null); 322 try { 323 pipeline.dispatch(request, null, null); 324 fail("expected exception"); 325 } catch(ServletException ex) { 326 for (StackTraceElement element : ex.getStackTrace()) { 327 String className = element.getClassName(); 328 assertTrue("was: " + element, 329 !className.equals(FilterChainInvocation.class.getName()) 330 && !className.equals(FilterDefinition.class.getName())); 331 } 332 } 333 } 334 testServletExceptionPrunesStack()335 public final void testServletExceptionPrunesStack() throws Exception { 336 Injector injector = Guice.createInjector(new ServletModule() { 337 @Override 338 protected void configureServlets() { 339 filter("/").through(TestFilter.class); 340 filter("/nothing").through(TestFilter.class); 341 serve("/").with(ThrowingServlet.class); 342 } 343 }); 344 345 HttpServletRequest request = newFakeHttpServletRequest(); 346 FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 347 pipeline.initPipeline(null); 348 try { 349 pipeline.dispatch(request, null, null); 350 fail("expected exception"); 351 } catch(ServletException ex) { 352 for (StackTraceElement element : ex.getStackTrace()) { 353 String className = element.getClassName(); 354 assertTrue("was: " + element, 355 !className.equals(FilterChainInvocation.class.getName()) 356 && !className.equals(FilterDefinition.class.getName())); 357 } 358 } 359 } 360 361 @Singleton 362 private static class ThrowingServlet extends HttpServlet { 363 364 @Override service(HttpServletRequest req, HttpServletResponse resp)365 protected void service(HttpServletRequest req, HttpServletResponse resp) 366 throws ServletException { 367 throw new ServletException("failure!"); 368 } 369 370 } 371 372 373 @Singleton 374 private static class ThrowingFilter implements Filter { 375 @Override destroy()376 public void destroy() { 377 } 378 379 @Override doFilter(ServletRequest request, ServletResponse response, FilterChain chain)380 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) 381 throws ServletException { 382 throw new ServletException("we failed!"); 383 } 384 385 @Override init(FilterConfig filterConfig)386 public void init(FilterConfig filterConfig) { 387 388 } 389 390 } 391 } 392