1 /* 2 * Copyright (C) 2011 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.inject.servlet; 18 19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST; 20 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest; 21 import static com.google.inject.servlet.ServletTestUtils.newNoOpFilterChain; 22 import static org.easymock.EasyMock.expect; 23 import static org.easymock.EasyMock.expectLastCall; 24 25 import com.google.inject.Guice; 26 import com.google.inject.Injector; 27 import com.google.inject.Key; 28 import com.google.inject.Singleton; 29 import java.io.IOException; 30 import java.util.ArrayList; 31 import java.util.List; 32 import java.util.concurrent.atomic.AtomicInteger; 33 import javax.servlet.Filter; 34 import javax.servlet.FilterChain; 35 import javax.servlet.FilterConfig; 36 import javax.servlet.ServletException; 37 import javax.servlet.ServletRequest; 38 import javax.servlet.ServletResponse; 39 import javax.servlet.http.HttpServlet; 40 import javax.servlet.http.HttpServletRequest; 41 import javax.servlet.http.HttpServletResponse; 42 import junit.framework.TestCase; 43 import org.easymock.EasyMock; 44 import org.easymock.IMocksControl; 45 46 /** 47 * This tests that filter stage of the pipeline dispatches correctly to guice-managed filters. 48 * 49 * <p>WARNING(dhanji): Non-parallelizable test =( 50 * 51 * @author dhanji@gmail.com (Dhanji R. Prasanna) 52 */ 53 public class FilterDispatchIntegrationTest extends TestCase { 54 private static int inits, doFilters, destroys; 55 56 private IMocksControl control; 57 58 @Override setUp()59 public final void setUp() { 60 inits = 0; 61 doFilters = 0; 62 destroys = 0; 63 control = EasyMock.createControl(); 64 GuiceFilter.reset(); 65 } 66 testDispatchRequestToManagedPipeline()67 public final void testDispatchRequestToManagedPipeline() throws ServletException, IOException { 68 final Injector injector = 69 Guice.createInjector( 70 new ServletModule() { 71 72 @Override 73 protected void configureServlets() { 74 filter("/*").through(TestFilter.class); 75 filter("*.html").through(TestFilter.class); 76 filter("/*").through(Key.get(TestFilter.class)); 77 78 // These filters should never fire 79 filter("/index/*").through(Key.get(TestFilter.class)); 80 filter("*.jsp").through(Key.get(TestFilter.class)); 81 82 // Bind a servlet 83 serve("*.html").with(TestServlet.class); 84 } 85 }); 86 87 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 88 pipeline.initPipeline(null); 89 90 // create ourselves a mock request with test URI 91 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 92 93 expect(requestMock.getRequestURI()).andReturn("/index.html").anyTimes(); 94 expect(requestMock.getContextPath()).andReturn("").anyTimes(); 95 96 requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true); 97 requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST); 98 99 HttpServletResponse responseMock = control.createMock(HttpServletResponse.class); 100 expect(responseMock.isCommitted()).andReturn(false).anyTimes(); 101 responseMock.resetBuffer(); 102 expectLastCall().anyTimes(); 103 104 FilterChain filterChain = control.createMock(FilterChain.class); 105 106 //dispatch request 107 control.replay(); 108 pipeline.dispatch(requestMock, responseMock, filterChain); 109 pipeline.destroyPipeline(); 110 control.verify(); 111 112 TestServlet servlet = injector.getInstance(TestServlet.class); 113 assertEquals(2, servlet.processedUris.size()); 114 assertTrue(servlet.processedUris.contains("/index.html")); 115 assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO)); 116 117 assertTrue( 118 "lifecycle states did not" 119 + " fire correct number of times-- inits: " 120 + inits 121 + "; dos: " 122 + doFilters 123 + "; destroys: " 124 + destroys, 125 inits == 1 && doFilters == 3 && destroys == 1); 126 } 127 testDispatchThatNoFiltersFire()128 public final void testDispatchThatNoFiltersFire() throws ServletException, IOException { 129 final Injector injector = 130 Guice.createInjector( 131 new ServletModule() { 132 133 @Override 134 protected void configureServlets() { 135 filter("/public/*").through(TestFilter.class); 136 filter("*.html").through(TestFilter.class); 137 filter("*.xml").through(Key.get(TestFilter.class)); 138 139 // These filters should never fire 140 filter("/index/*").through(Key.get(TestFilter.class)); 141 filter("*.jsp").through(Key.get(TestFilter.class)); 142 } 143 }); 144 145 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 146 pipeline.initPipeline(null); 147 148 //create ourselves a mock request with test URI 149 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 150 151 expect(requestMock.getRequestURI()).andReturn("/index.xhtml").anyTimes(); 152 expect(requestMock.getContextPath()).andReturn("").anyTimes(); 153 154 //dispatch request 155 FilterChain filterChain = control.createMock(FilterChain.class); 156 filterChain.doFilter(requestMock, null); 157 control.replay(); 158 pipeline.dispatch(requestMock, null, filterChain); 159 pipeline.destroyPipeline(); 160 control.verify(); 161 162 assertTrue( 163 "lifecycle states did not " 164 + "fire correct number of times-- inits: " 165 + inits 166 + "; dos: " 167 + doFilters 168 + "; destroys: " 169 + destroys, 170 inits == 1 && doFilters == 0 && destroys == 1); 171 } 172 testDispatchFilterPipelineWithRegexMatching()173 public final void testDispatchFilterPipelineWithRegexMatching() 174 throws ServletException, IOException { 175 176 final Injector injector = 177 Guice.createInjector( 178 new ServletModule() { 179 180 @Override 181 protected void configureServlets() { 182 filterRegex("/[A-Za-z]*").through(TestFilter.class); 183 filterRegex("/index").through(TestFilter.class); 184 //these filters should never fire 185 filterRegex("\\w").through(Key.get(TestFilter.class)); 186 } 187 }); 188 189 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 190 pipeline.initPipeline(null); 191 192 //create ourselves a mock request with test URI 193 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 194 195 expect(requestMock.getRequestURI()).andReturn("/index").anyTimes(); 196 expect(requestMock.getContextPath()).andReturn("").anyTimes(); 197 198 // dispatch request 199 FilterChain filterChain = control.createMock(FilterChain.class); 200 filterChain.doFilter(requestMock, null); 201 control.replay(); 202 pipeline.dispatch(requestMock, null, filterChain); 203 pipeline.destroyPipeline(); 204 control.verify(); 205 206 assertTrue( 207 "lifecycle states did not fire " 208 + "correct number of times-- inits: " 209 + inits 210 + "; dos: " 211 + doFilters 212 + "; destroys: " 213 + destroys, 214 inits == 1 && doFilters == 2 && destroys == 1); 215 } 216 217 @Singleton 218 public static class TestFilter implements Filter { 219 @Override init(FilterConfig filterConfig)220 public void init(FilterConfig filterConfig) { 221 inits++; 222 } 223 224 @Override doFilter( ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)225 public void doFilter( 226 ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) 227 throws IOException, ServletException { 228 doFilters++; 229 filterChain.doFilter(servletRequest, servletResponse); 230 } 231 232 @Override destroy()233 public void destroy() { 234 destroys++; 235 } 236 } 237 testFilterBypass()238 public final void testFilterBypass() throws ServletException, IOException { 239 240 final Injector injector = 241 Guice.createInjector( 242 new ServletModule() { 243 @Override 244 protected void configureServlets() { 245 filter("/protected/*").through(TestFilter.class); 246 } 247 }); 248 249 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 250 pipeline.initPipeline(null); 251 assertEquals(1, inits); 252 253 runRequestForPath(pipeline, "/./protected/resource", true); 254 runRequestForPath(pipeline, "/protected/../resource", false); 255 runRequestForPath(pipeline, "/protected/../protected/resource", true); 256 257 assertEquals(0, destroys); 258 pipeline.destroyPipeline(); 259 assertEquals(1, destroys); 260 } 261 runRequestForPath(FilterPipeline pipeline, String value, boolean matches)262 private void runRequestForPath(FilterPipeline pipeline, String value, boolean matches) 263 throws IOException, ServletException { 264 assertEquals(0, doFilters); 265 //create ourselves a mock request with test URI 266 HttpServletRequest requestMock = control.createMock(HttpServletRequest.class); 267 expect(requestMock.getRequestURI()).andReturn(value).anyTimes(); 268 expect(requestMock.getContextPath()).andReturn("").anyTimes(); 269 // dispatch request 270 FilterChain filterChain = control.createMock(FilterChain.class); 271 filterChain.doFilter(requestMock, null); 272 control.replay(); 273 pipeline.dispatch(requestMock, null, filterChain); 274 control.verify(); 275 control.reset(); 276 if (matches) { 277 assertEquals("filter was not run", 1, doFilters); 278 doFilters = 0; 279 } else { 280 assertEquals("filter was run", 0, doFilters); 281 } 282 } 283 284 @Singleton 285 public static class TestServlet extends HttpServlet { 286 public static final String FORWARD_FROM = "/index.html"; 287 public static final String FORWARD_TO = "/forwarded.html"; 288 public List<String> processedUris = new ArrayList<>(); 289 290 @Override service( HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)291 protected void service( 292 HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) 293 throws ServletException, IOException { 294 String requestUri = httpServletRequest.getRequestURI(); 295 processedUris.add(requestUri); 296 297 // If the client is requesting /index.html then we forward to /forwarded.html 298 if (FORWARD_FROM.equals(requestUri)) { 299 httpServletRequest 300 .getRequestDispatcher(FORWARD_TO) 301 .forward(httpServletRequest, httpServletResponse); 302 } 303 } 304 305 @Override service(ServletRequest servletRequest, ServletResponse servletResponse)306 public void service(ServletRequest servletRequest, ServletResponse servletResponse) 307 throws ServletException, IOException { 308 service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse); 309 } 310 } 311 testFilterOrder()312 public void testFilterOrder() throws Exception { 313 AtomicInteger counter = new AtomicInteger(); 314 final CountFilter f1 = new CountFilter(counter); 315 final CountFilter f2 = new CountFilter(counter); 316 317 Injector injector = 318 Guice.createInjector( 319 new ServletModule() { 320 @Override 321 protected void configureServlets() { 322 filter("/").through(f1); 323 install( 324 new ServletModule() { 325 @Override 326 protected void configureServlets() { 327 filter("/").through(f2); 328 } 329 }); 330 } 331 }); 332 333 HttpServletRequest request = newFakeHttpServletRequest(); 334 final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 335 pipeline.initPipeline(null); 336 pipeline.dispatch(request, null, newNoOpFilterChain()); 337 assertEquals(0, f1.calledAt); 338 assertEquals(1, f2.calledAt); 339 } 340 341 /** A filter that keeps count of when it was called by increment a counter. */ 342 private static class CountFilter implements Filter { 343 private final AtomicInteger counter; 344 private int calledAt = -1; 345 CountFilter(AtomicInteger counter)346 public CountFilter(AtomicInteger counter) { 347 this.counter = counter; 348 } 349 350 @Override destroy()351 public void destroy() {} 352 353 @Override doFilter(ServletRequest request, ServletResponse response, FilterChain chain)354 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) 355 throws ServletException, IOException { 356 if (calledAt != -1) { 357 fail("not expecting to be called twice"); 358 } 359 calledAt = counter.getAndIncrement(); 360 chain.doFilter(request, response); 361 } 362 363 @Override init(FilterConfig filterConfig)364 public void init(FilterConfig filterConfig) {} 365 } 366 testFilterExceptionPrunesStack()367 public final void testFilterExceptionPrunesStack() throws Exception { 368 Injector injector = 369 Guice.createInjector( 370 new ServletModule() { 371 @Override 372 protected void configureServlets() { 373 filter("/").through(TestFilter.class); 374 filter("/nothing").through(TestFilter.class); 375 filter("/").through(ThrowingFilter.class); 376 } 377 }); 378 379 HttpServletRequest request = newFakeHttpServletRequest(); 380 FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 381 pipeline.initPipeline(null); 382 try { 383 pipeline.dispatch(request, null, null); 384 fail("expected exception"); 385 } catch (ServletException ex) { 386 for (StackTraceElement element : ex.getStackTrace()) { 387 String className = element.getClassName(); 388 assertTrue( 389 "was: " + element, 390 !className.equals(FilterChainInvocation.class.getName()) 391 && !className.equals(FilterDefinition.class.getName())); 392 } 393 } 394 } 395 testServletExceptionPrunesStack()396 public final void testServletExceptionPrunesStack() throws Exception { 397 Injector injector = 398 Guice.createInjector( 399 new ServletModule() { 400 @Override 401 protected void configureServlets() { 402 filter("/").through(TestFilter.class); 403 filter("/nothing").through(TestFilter.class); 404 serve("/").with(ThrowingServlet.class); 405 } 406 }); 407 408 HttpServletRequest request = newFakeHttpServletRequest(); 409 FilterPipeline pipeline = injector.getInstance(FilterPipeline.class); 410 pipeline.initPipeline(null); 411 try { 412 pipeline.dispatch(request, null, null); 413 fail("expected exception"); 414 } catch (ServletException ex) { 415 for (StackTraceElement element : ex.getStackTrace()) { 416 String className = element.getClassName(); 417 assertTrue( 418 "was: " + element, 419 !className.equals(FilterChainInvocation.class.getName()) 420 && !className.equals(FilterDefinition.class.getName())); 421 } 422 } 423 } 424 425 @Singleton 426 private static class ThrowingServlet extends HttpServlet { 427 428 @Override service(HttpServletRequest req, HttpServletResponse resp)429 protected void service(HttpServletRequest req, HttpServletResponse resp) 430 throws ServletException { 431 throw new ServletException("failure!"); 432 } 433 } 434 435 @Singleton 436 private static class ThrowingFilter implements Filter { 437 @Override destroy()438 public void destroy() {} 439 440 @Override doFilter(ServletRequest request, ServletResponse response, FilterChain chain)441 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) 442 throws ServletException { 443 throw new ServletException("we failed!"); 444 } 445 446 @Override init(FilterConfig filterConfig)447 public void init(FilterConfig filterConfig) {} 448 } 449 } 450