• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2011 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.google.inject.servlet;
18 
19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
20 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest;
21 import static com.google.inject.servlet.ServletTestUtils.newNoOpFilterChain;
22 import static org.easymock.EasyMock.expect;
23 import static org.easymock.EasyMock.expectLastCall;
24 
25 import com.google.inject.Guice;
26 import com.google.inject.Injector;
27 import com.google.inject.Key;
28 import com.google.inject.Singleton;
29 import java.io.IOException;
30 import java.util.ArrayList;
31 import java.util.List;
32 import java.util.concurrent.atomic.AtomicInteger;
33 import javax.servlet.Filter;
34 import javax.servlet.FilterChain;
35 import javax.servlet.FilterConfig;
36 import javax.servlet.ServletException;
37 import javax.servlet.ServletRequest;
38 import javax.servlet.ServletResponse;
39 import javax.servlet.http.HttpServlet;
40 import javax.servlet.http.HttpServletRequest;
41 import javax.servlet.http.HttpServletResponse;
42 import junit.framework.TestCase;
43 import org.easymock.EasyMock;
44 import org.easymock.IMocksControl;
45 
46 /**
47  * This tests that filter stage of the pipeline dispatches correctly to guice-managed filters.
48  *
49  * <p>WARNING(dhanji): Non-parallelizable test =(
50  *
51  * @author dhanji@gmail.com (Dhanji R. Prasanna)
52  */
53 public class FilterDispatchIntegrationTest extends TestCase {
54   private static int inits, doFilters, destroys;
55 
56   private IMocksControl control;
57 
58   @Override
setUp()59   public final void setUp() {
60     inits = 0;
61     doFilters = 0;
62     destroys = 0;
63     control = EasyMock.createControl();
64     GuiceFilter.reset();
65   }
66 
testDispatchRequestToManagedPipeline()67   public final void testDispatchRequestToManagedPipeline() throws ServletException, IOException {
68     final Injector injector =
69         Guice.createInjector(
70             new ServletModule() {
71 
72               @Override
73               protected void configureServlets() {
74                 filter("/*").through(TestFilter.class);
75                 filter("*.html").through(TestFilter.class);
76                 filter("/*").through(Key.get(TestFilter.class));
77 
78                 // These filters should never fire
79                 filter("/index/*").through(Key.get(TestFilter.class));
80                 filter("*.jsp").through(Key.get(TestFilter.class));
81 
82                 // Bind a servlet
83                 serve("*.html").with(TestServlet.class);
84               }
85             });
86 
87     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
88     pipeline.initPipeline(null);
89 
90     // create ourselves a mock request with test URI
91     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
92 
93     expect(requestMock.getRequestURI()).andReturn("/index.html").anyTimes();
94     expect(requestMock.getContextPath()).andReturn("").anyTimes();
95 
96     requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
97     requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
98 
99     HttpServletResponse responseMock = control.createMock(HttpServletResponse.class);
100     expect(responseMock.isCommitted()).andReturn(false).anyTimes();
101     responseMock.resetBuffer();
102     expectLastCall().anyTimes();
103 
104     FilterChain filterChain = control.createMock(FilterChain.class);
105 
106     //dispatch request
107     control.replay();
108     pipeline.dispatch(requestMock, responseMock, filterChain);
109     pipeline.destroyPipeline();
110     control.verify();
111 
112     TestServlet servlet = injector.getInstance(TestServlet.class);
113     assertEquals(2, servlet.processedUris.size());
114     assertTrue(servlet.processedUris.contains("/index.html"));
115     assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO));
116 
117     assertTrue(
118         "lifecycle states did not"
119             + " fire correct number of times-- inits: "
120             + inits
121             + "; dos: "
122             + doFilters
123             + "; destroys: "
124             + destroys,
125         inits == 1 && doFilters == 3 && destroys == 1);
126   }
127 
testDispatchThatNoFiltersFire()128   public final void testDispatchThatNoFiltersFire() throws ServletException, IOException {
129     final Injector injector =
130         Guice.createInjector(
131             new ServletModule() {
132 
133               @Override
134               protected void configureServlets() {
135                 filter("/public/*").through(TestFilter.class);
136                 filter("*.html").through(TestFilter.class);
137                 filter("*.xml").through(Key.get(TestFilter.class));
138 
139                 // These filters should never fire
140                 filter("/index/*").through(Key.get(TestFilter.class));
141                 filter("*.jsp").through(Key.get(TestFilter.class));
142               }
143             });
144 
145     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
146     pipeline.initPipeline(null);
147 
148     //create ourselves a mock request with test URI
149     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
150 
151     expect(requestMock.getRequestURI()).andReturn("/index.xhtml").anyTimes();
152     expect(requestMock.getContextPath()).andReturn("").anyTimes();
153 
154     //dispatch request
155     FilterChain filterChain = control.createMock(FilterChain.class);
156     filterChain.doFilter(requestMock, null);
157     control.replay();
158     pipeline.dispatch(requestMock, null, filterChain);
159     pipeline.destroyPipeline();
160     control.verify();
161 
162     assertTrue(
163         "lifecycle states did not "
164             + "fire correct number of times-- inits: "
165             + inits
166             + "; dos: "
167             + doFilters
168             + "; destroys: "
169             + destroys,
170         inits == 1 && doFilters == 0 && destroys == 1);
171   }
172 
testDispatchFilterPipelineWithRegexMatching()173   public final void testDispatchFilterPipelineWithRegexMatching()
174       throws ServletException, IOException {
175 
176     final Injector injector =
177         Guice.createInjector(
178             new ServletModule() {
179 
180               @Override
181               protected void configureServlets() {
182                 filterRegex("/[A-Za-z]*").through(TestFilter.class);
183                 filterRegex("/index").through(TestFilter.class);
184                 //these filters should never fire
185                 filterRegex("\\w").through(Key.get(TestFilter.class));
186               }
187             });
188 
189     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
190     pipeline.initPipeline(null);
191 
192     //create ourselves a mock request with test URI
193     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
194 
195     expect(requestMock.getRequestURI()).andReturn("/index").anyTimes();
196     expect(requestMock.getContextPath()).andReturn("").anyTimes();
197 
198     // dispatch request
199     FilterChain filterChain = control.createMock(FilterChain.class);
200     filterChain.doFilter(requestMock, null);
201     control.replay();
202     pipeline.dispatch(requestMock, null, filterChain);
203     pipeline.destroyPipeline();
204     control.verify();
205 
206     assertTrue(
207         "lifecycle states did not fire "
208             + "correct number of times-- inits: "
209             + inits
210             + "; dos: "
211             + doFilters
212             + "; destroys: "
213             + destroys,
214         inits == 1 && doFilters == 2 && destroys == 1);
215   }
216 
217   @Singleton
218   public static class TestFilter implements Filter {
219     @Override
init(FilterConfig filterConfig)220     public void init(FilterConfig filterConfig) {
221       inits++;
222     }
223 
224     @Override
doFilter( ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)225     public void doFilter(
226         ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
227         throws IOException, ServletException {
228       doFilters++;
229       filterChain.doFilter(servletRequest, servletResponse);
230     }
231 
232     @Override
destroy()233     public void destroy() {
234       destroys++;
235     }
236   }
237 
testFilterBypass()238   public final void testFilterBypass() throws ServletException, IOException {
239 
240     final Injector injector =
241         Guice.createInjector(
242             new ServletModule() {
243               @Override
244               protected void configureServlets() {
245                 filter("/protected/*").through(TestFilter.class);
246               }
247             });
248 
249     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
250     pipeline.initPipeline(null);
251     assertEquals(1, inits);
252 
253     runRequestForPath(pipeline, "/./protected/resource", true);
254     runRequestForPath(pipeline, "/protected/../resource", false);
255     runRequestForPath(pipeline, "/protected/../protected/resource", true);
256 
257     assertEquals(0, destroys);
258     pipeline.destroyPipeline();
259     assertEquals(1, destroys);
260   }
261 
runRequestForPath(FilterPipeline pipeline, String value, boolean matches)262   private void runRequestForPath(FilterPipeline pipeline, String value, boolean matches)
263       throws IOException, ServletException {
264     assertEquals(0, doFilters);
265     //create ourselves a mock request with test URI
266     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
267     expect(requestMock.getRequestURI()).andReturn(value).anyTimes();
268     expect(requestMock.getContextPath()).andReturn("").anyTimes();
269     // dispatch request
270     FilterChain filterChain = control.createMock(FilterChain.class);
271     filterChain.doFilter(requestMock, null);
272     control.replay();
273     pipeline.dispatch(requestMock, null, filterChain);
274     control.verify();
275     control.reset();
276     if (matches) {
277       assertEquals("filter was not run", 1, doFilters);
278       doFilters = 0;
279     } else {
280       assertEquals("filter was run", 0, doFilters);
281     }
282   }
283 
284   @Singleton
285   public static class TestServlet extends HttpServlet {
286     public static final String FORWARD_FROM = "/index.html";
287     public static final String FORWARD_TO = "/forwarded.html";
288     public List<String> processedUris = new ArrayList<>();
289 
290     @Override
service( HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)291     protected void service(
292         HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
293         throws ServletException, IOException {
294       String requestUri = httpServletRequest.getRequestURI();
295       processedUris.add(requestUri);
296 
297       // If the client is requesting /index.html then we forward to /forwarded.html
298       if (FORWARD_FROM.equals(requestUri)) {
299         httpServletRequest
300             .getRequestDispatcher(FORWARD_TO)
301             .forward(httpServletRequest, httpServletResponse);
302       }
303     }
304 
305     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)306     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
307         throws ServletException, IOException {
308       service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
309     }
310   }
311 
testFilterOrder()312   public void testFilterOrder() throws Exception {
313     AtomicInteger counter = new AtomicInteger();
314     final CountFilter f1 = new CountFilter(counter);
315     final CountFilter f2 = new CountFilter(counter);
316 
317     Injector injector =
318         Guice.createInjector(
319             new ServletModule() {
320               @Override
321               protected void configureServlets() {
322                 filter("/").through(f1);
323                 install(
324                     new ServletModule() {
325                       @Override
326                       protected void configureServlets() {
327                         filter("/").through(f2);
328                       }
329                     });
330               }
331             });
332 
333     HttpServletRequest request = newFakeHttpServletRequest();
334     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
335     pipeline.initPipeline(null);
336     pipeline.dispatch(request, null, newNoOpFilterChain());
337     assertEquals(0, f1.calledAt);
338     assertEquals(1, f2.calledAt);
339   }
340 
341   /** A filter that keeps count of when it was called by increment a counter. */
342   private static class CountFilter implements Filter {
343     private final AtomicInteger counter;
344     private int calledAt = -1;
345 
CountFilter(AtomicInteger counter)346     public CountFilter(AtomicInteger counter) {
347       this.counter = counter;
348     }
349 
350     @Override
destroy()351     public void destroy() {}
352 
353     @Override
doFilter(ServletRequest request, ServletResponse response, FilterChain chain)354     public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
355         throws ServletException, IOException {
356       if (calledAt != -1) {
357         fail("not expecting to be called twice");
358       }
359       calledAt = counter.getAndIncrement();
360       chain.doFilter(request, response);
361     }
362 
363     @Override
init(FilterConfig filterConfig)364     public void init(FilterConfig filterConfig) {}
365   }
366 
testFilterExceptionPrunesStack()367   public final void testFilterExceptionPrunesStack() throws Exception {
368     Injector injector =
369         Guice.createInjector(
370             new ServletModule() {
371               @Override
372               protected void configureServlets() {
373                 filter("/").through(TestFilter.class);
374                 filter("/nothing").through(TestFilter.class);
375                 filter("/").through(ThrowingFilter.class);
376               }
377             });
378 
379     HttpServletRequest request = newFakeHttpServletRequest();
380     FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
381     pipeline.initPipeline(null);
382     try {
383       pipeline.dispatch(request, null, null);
384       fail("expected exception");
385     } catch (ServletException ex) {
386       for (StackTraceElement element : ex.getStackTrace()) {
387         String className = element.getClassName();
388         assertTrue(
389             "was: " + element,
390             !className.equals(FilterChainInvocation.class.getName())
391                 && !className.equals(FilterDefinition.class.getName()));
392       }
393     }
394   }
395 
testServletExceptionPrunesStack()396   public final void testServletExceptionPrunesStack() throws Exception {
397     Injector injector =
398         Guice.createInjector(
399             new ServletModule() {
400               @Override
401               protected void configureServlets() {
402                 filter("/").through(TestFilter.class);
403                 filter("/nothing").through(TestFilter.class);
404                 serve("/").with(ThrowingServlet.class);
405               }
406             });
407 
408     HttpServletRequest request = newFakeHttpServletRequest();
409     FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
410     pipeline.initPipeline(null);
411     try {
412       pipeline.dispatch(request, null, null);
413       fail("expected exception");
414     } catch (ServletException ex) {
415       for (StackTraceElement element : ex.getStackTrace()) {
416         String className = element.getClassName();
417         assertTrue(
418             "was: " + element,
419             !className.equals(FilterChainInvocation.class.getName())
420                 && !className.equals(FilterDefinition.class.getName()));
421       }
422     }
423   }
424 
425   @Singleton
426   private static class ThrowingServlet extends HttpServlet {
427 
428     @Override
service(HttpServletRequest req, HttpServletResponse resp)429     protected void service(HttpServletRequest req, HttpServletResponse resp)
430         throws ServletException {
431       throw new ServletException("failure!");
432     }
433   }
434 
435   @Singleton
436   private static class ThrowingFilter implements Filter {
437     @Override
destroy()438     public void destroy() {}
439 
440     @Override
doFilter(ServletRequest request, ServletResponse response, FilterChain chain)441     public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
442         throws ServletException {
443       throw new ServletException("we failed!");
444     }
445 
446     @Override
init(FilterConfig filterConfig)447     public void init(FilterConfig filterConfig) {}
448   }
449 }
450