• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (C) 2011 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.inject.servlet;
18 
19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
20 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest;
21 import static com.google.inject.servlet.ServletTestUtils.newNoOpFilterChain;
22 import static org.easymock.EasyMock.expect;
23 import static org.easymock.EasyMock.expectLastCall;
24 
25 import com.google.inject.Guice;
26 import com.google.inject.Injector;
27 import com.google.inject.Key;
28 import com.google.inject.Singleton;
29 
30 import junit.framework.TestCase;
31 
32 import org.easymock.EasyMock;
33 import org.easymock.IMocksControl;
34 
35 import java.io.IOException;
36 import java.util.ArrayList;
37 import java.util.List;
38 import java.util.concurrent.atomic.AtomicInteger;
39 
40 import javax.servlet.Filter;
41 import javax.servlet.FilterChain;
42 import javax.servlet.FilterConfig;
43 import javax.servlet.ServletException;
44 import javax.servlet.ServletRequest;
45 import javax.servlet.ServletResponse;
46 import javax.servlet.http.HttpServlet;
47 import javax.servlet.http.HttpServletRequest;
48 import javax.servlet.http.HttpServletResponse;
49 
50 /**
51  *
52  * This tests that filter stage of the pipeline dispatches
53  * correctly to guice-managed filters.
54  *
55  * WARNING(dhanji): Non-parallelizable test =(
56  *
57  * @author dhanji@gmail.com (Dhanji R. Prasanna)
58  */
59 public class FilterDispatchIntegrationTest extends TestCase {
60     private static int inits, doFilters, destroys;
61 
62   private IMocksControl control;
63 
64   @Override
setUp()65   public final void setUp() {
66     inits = 0;
67     doFilters = 0;
68     destroys = 0;
69     control = EasyMock.createControl();
70     GuiceFilter.reset();
71   }
72 
73 
testDispatchRequestToManagedPipeline()74   public final void testDispatchRequestToManagedPipeline() throws ServletException, IOException {
75     final Injector injector = Guice.createInjector(new ServletModule() {
76 
77       @Override
78       protected void configureServlets() {
79         filter("/*").through(TestFilter.class);
80         filter("*.html").through(TestFilter.class);
81         filter("/*").through(Key.get(TestFilter.class));
82 
83         // These filters should never fire
84         filter("/index/*").through(Key.get(TestFilter.class));
85         filter("*.jsp").through(Key.get(TestFilter.class));
86 
87         // Bind a servlet
88         serve("*.html").with(TestServlet.class);
89       }
90     });
91 
92     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
93     pipeline.initPipeline(null);
94 
95     // create ourselves a mock request with test URI
96     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
97 
98     expect(requestMock.getRequestURI())
99             .andReturn("/index.html")
100             .anyTimes();
101     expect(requestMock.getContextPath())
102         .andReturn("")
103         .anyTimes();
104 
105     requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
106     requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
107 
108     HttpServletResponse responseMock = control.createMock(HttpServletResponse.class);
109     expect(responseMock.isCommitted())
110         .andReturn(false)
111         .anyTimes();
112     responseMock.resetBuffer();
113     expectLastCall().anyTimes();
114 
115     FilterChain filterChain = control.createMock(FilterChain.class);
116 
117     //dispatch request
118     control.replay();
119     pipeline.dispatch(requestMock, responseMock, filterChain);
120     pipeline.destroyPipeline();
121     control.verify();
122 
123     TestServlet servlet = injector.getInstance(TestServlet.class);
124     assertEquals(2, servlet.processedUris.size());
125     assertTrue(servlet.processedUris.contains("/index.html"));
126     assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO));
127 
128     assertTrue("lifecycle states did not"
129         + " fire correct number of times-- inits: " + inits + "; dos: " + doFilters
130         + "; destroys: " + destroys, inits == 1 && doFilters == 3 && destroys == 1);
131   }
132 
testDispatchThatNoFiltersFire()133   public final void testDispatchThatNoFiltersFire() throws ServletException, IOException {
134     final Injector injector = Guice.createInjector(new ServletModule() {
135 
136       @Override
137       protected void configureServlets() {
138         filter("/public/*").through(TestFilter.class);
139         filter("*.html").through(TestFilter.class);
140         filter("*.xml").through(Key.get(TestFilter.class));
141 
142         // These filters should never fire
143         filter("/index/*").through(Key.get(TestFilter.class));
144         filter("*.jsp").through(Key.get(TestFilter.class));
145       }
146     });
147 
148     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
149     pipeline.initPipeline(null);
150 
151     //create ourselves a mock request with test URI
152     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
153 
154     expect(requestMock.getRequestURI())
155             .andReturn("/index.xhtml")
156             .anyTimes();
157     expect(requestMock.getContextPath())
158         .andReturn("")
159         .anyTimes();
160 
161     //dispatch request
162     FilterChain filterChain = control.createMock(FilterChain.class);
163     filterChain.doFilter(requestMock, null);
164     control.replay();
165     pipeline.dispatch(requestMock, null, filterChain);
166     pipeline.destroyPipeline();
167     control.verify();
168 
169     assertTrue("lifecycle states did not "
170         + "fire correct number of times-- inits: " + inits + "; dos: " + doFilters
171         + "; destroys: " + destroys,
172         inits == 1 && doFilters == 0 && destroys == 1);
173   }
174 
testDispatchFilterPipelineWithRegexMatching()175   public final void testDispatchFilterPipelineWithRegexMatching() throws ServletException,
176       IOException {
177 
178     final Injector injector = Guice.createInjector(new ServletModule() {
179 
180       @Override
181       protected void configureServlets() {
182         filterRegex("/[A-Za-z]*").through(TestFilter.class);
183         filterRegex("/index").through(TestFilter.class);
184         //these filters should never fire
185         filterRegex("\\w").through(Key.get(TestFilter.class));
186       }
187     });
188 
189     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
190     pipeline.initPipeline(null);
191 
192     //create ourselves a mock request with test URI
193     HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
194 
195     expect(requestMock.getRequestURI())
196             .andReturn("/index")
197             .anyTimes();
198     expect(requestMock.getContextPath())
199         .andReturn("")
200         .anyTimes();
201 
202     // dispatch request
203     FilterChain filterChain = control.createMock(FilterChain.class);
204     filterChain.doFilter(requestMock, null);
205     control.replay();
206     pipeline.dispatch(requestMock, null, filterChain);
207     pipeline.destroyPipeline();
208     control.verify();
209 
210     assertTrue("lifecycle states did not fire "
211         + "correct number of times-- inits: " + inits + "; dos: " + doFilters
212         + "; destroys: " + destroys,
213         inits == 1 && doFilters == 2 && destroys == 1);
214   }
215 
216   @Singleton
217   public static class TestFilter implements Filter {
init(FilterConfig filterConfig)218     public void init(FilterConfig filterConfig) {
219       inits++;
220     }
221 
doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)222     public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
223         FilterChain filterChain) throws IOException, ServletException {
224       doFilters++;
225       filterChain.doFilter(servletRequest, servletResponse);
226     }
227 
destroy()228     public void destroy() {
229       destroys++;
230     }
231   }
232 
233   @Singleton
234   public static class TestServlet extends HttpServlet {
235     public static final String FORWARD_FROM = "/index.html";
236     public static final String FORWARD_TO = "/forwarded.html";
237     public List<String> processedUris = new ArrayList<String>();
238 
239     @Override
service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)240     protected void service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
241         throws ServletException, IOException {
242       String requestUri = httpServletRequest.getRequestURI();
243       processedUris.add(requestUri);
244 
245       // If the client is requesting /index.html then we forward to /forwarded.html
246       if (FORWARD_FROM.equals(requestUri)) {
247         httpServletRequest.getRequestDispatcher(FORWARD_TO)
248             .forward(httpServletRequest, httpServletResponse);
249       }
250     }
251 
252     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)253     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
254         throws ServletException, IOException {
255       service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
256     }
257   }
258 
testFilterOrder()259   public void testFilterOrder() throws Exception {
260     AtomicInteger counter = new AtomicInteger();
261     final CountFilter f1 = new CountFilter(counter);
262     final CountFilter f2 = new CountFilter(counter);
263 
264     Injector injector = Guice.createInjector(new ServletModule() {
265       @Override
266       protected void configureServlets() {
267         filter("/").through(f1);
268         install(new ServletModule() {
269           @Override
270           protected void configureServlets() {
271             filter("/").through(f2);
272           }
273         });
274       }
275     });
276 
277     HttpServletRequest request = newFakeHttpServletRequest();
278     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
279     pipeline.initPipeline(null);
280     pipeline.dispatch(request, null, newNoOpFilterChain());
281     assertEquals(0, f1.calledAt);
282     assertEquals(1, f2.calledAt);
283   }
284 
285   /** A filter that keeps count of when it was called by increment a counter. */
286   private static class CountFilter implements Filter {
287     private final AtomicInteger counter;
288     private int calledAt = -1;
289 
CountFilter(AtomicInteger counter)290     public CountFilter(AtomicInteger counter) {
291       this.counter = counter;
292     }
293 
destroy()294     public void destroy() {
295     }
296 
doFilter(ServletRequest request, ServletResponse response, FilterChain chain)297     public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
298         throws ServletException, IOException {
299       if (calledAt != -1) {
300         fail("not expecting to be called twice");
301       }
302       calledAt = counter.getAndIncrement();
303       chain.doFilter(request, response);
304     }
305 
init(FilterConfig filterConfig)306     public void init(FilterConfig filterConfig) {}
307   }
308 
testFilterExceptionPrunesStack()309   public final void testFilterExceptionPrunesStack() throws Exception {
310     Injector injector = Guice.createInjector(new ServletModule() {
311       @Override
312       protected void configureServlets() {
313         filter("/").through(TestFilter.class);
314         filter("/nothing").through(TestFilter.class);
315         filter("/").through(ThrowingFilter.class);
316       }
317     });
318 
319     HttpServletRequest request = newFakeHttpServletRequest();
320     FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
321     pipeline.initPipeline(null);
322     try {
323       pipeline.dispatch(request, null, null);
324       fail("expected exception");
325     } catch(ServletException ex) {
326       for (StackTraceElement element : ex.getStackTrace()) {
327         String className = element.getClassName();
328         assertTrue("was: " + element,
329             !className.equals(FilterChainInvocation.class.getName())
330             && !className.equals(FilterDefinition.class.getName()));
331       }
332     }
333   }
334 
testServletExceptionPrunesStack()335   public final void testServletExceptionPrunesStack() throws Exception {
336     Injector injector = Guice.createInjector(new ServletModule() {
337       @Override
338       protected void configureServlets() {
339         filter("/").through(TestFilter.class);
340         filter("/nothing").through(TestFilter.class);
341         serve("/").with(ThrowingServlet.class);
342       }
343     });
344 
345     HttpServletRequest request = newFakeHttpServletRequest();
346     FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
347     pipeline.initPipeline(null);
348     try {
349       pipeline.dispatch(request, null, null);
350       fail("expected exception");
351     } catch(ServletException ex) {
352       for (StackTraceElement element : ex.getStackTrace()) {
353         String className = element.getClassName();
354         assertTrue("was: " + element,
355             !className.equals(FilterChainInvocation.class.getName())
356             && !className.equals(FilterDefinition.class.getName()));
357       }
358     }
359   }
360 
361   @Singleton
362   private static class ThrowingServlet extends HttpServlet {
363 
364     @Override
service(HttpServletRequest req, HttpServletResponse resp)365     protected void service(HttpServletRequest req, HttpServletResponse resp)
366         throws ServletException {
367       throw new ServletException("failure!");
368     }
369 
370   }
371 
372 
373   @Singleton
374   private static class ThrowingFilter implements Filter {
375     @Override
destroy()376     public void destroy() {
377     }
378 
379     @Override
doFilter(ServletRequest request, ServletResponse response, FilterChain chain)380     public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
381         throws ServletException {
382       throw new ServletException("we failed!");
383     }
384 
385     @Override
init(FilterConfig filterConfig)386     public void init(FilterConfig filterConfig) {
387 
388     }
389 
390   }
391 }
392