• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2008 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.google.inject.servlet;
18 
19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
20 import static org.easymock.EasyMock.createMock;
21 import static org.easymock.EasyMock.expect;
22 import static org.easymock.EasyMock.replay;
23 import static org.easymock.EasyMock.verify;
24 
25 import com.google.inject.Guice;
26 import com.google.inject.Injector;
27 import com.google.inject.Key;
28 import com.google.inject.Singleton;
29 import java.io.IOException;
30 import javax.servlet.Filter;
31 import javax.servlet.FilterChain;
32 import javax.servlet.FilterConfig;
33 import javax.servlet.ServletConfig;
34 import javax.servlet.ServletException;
35 import javax.servlet.ServletRequest;
36 import javax.servlet.ServletResponse;
37 import javax.servlet.http.HttpServlet;
38 import javax.servlet.http.HttpServletRequest;
39 import javax.servlet.http.HttpServletResponse;
40 import junit.framework.TestCase;
41 
42 /**
43  * Tests the FilterPipeline that dispatches to guice-managed servlets, is a full integration test,
44  * with a real injector.
45  *
46  * @author Dhanji R. Prasanna (dhanji gmail com)
47  */
48 public class ServletDispatchIntegrationTest extends TestCase {
49   private static int inits, services, destroys, doFilters;
50 
51   @Override
setUp()52   public void setUp() {
53     inits = 0;
54     services = 0;
55     destroys = 0;
56     doFilters = 0;
57 
58     GuiceFilter.reset();
59   }
60 
testDispatchRequestToManagedPipelineServlets()61   public final void testDispatchRequestToManagedPipelineServlets()
62       throws ServletException, IOException {
63     final Injector injector =
64         Guice.createInjector(
65             new ServletModule() {
66 
67               @Override
68               protected void configureServlets() {
69                 serve("/*").with(TestServlet.class);
70 
71                 // These servets should never fire... (ordering test)
72                 serve("*.html").with(NeverServlet.class);
73                 serve("/test/*").with(Key.get(NeverServlet.class));
74                 serve("/index/*").with(Key.get(NeverServlet.class));
75                 serve("*.jsp").with(Key.get(NeverServlet.class));
76               }
77             });
78 
79     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
80 
81     pipeline.initPipeline(null);
82 
83     //create ourselves a mock request with test URI
84     HttpServletRequest requestMock = createMock(HttpServletRequest.class);
85 
86     expect(requestMock.getRequestURI()).andReturn("/index.html").times(1);
87     expect(requestMock.getContextPath()).andReturn("").anyTimes();
88 
89     //dispatch request
90     replay(requestMock);
91 
92     pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
93 
94     pipeline.destroyPipeline();
95 
96     verify(requestMock);
97 
98     assertTrue(
99         "lifecycle states did not fire correct number of times-- inits: "
100             + inits
101             + "; dos: "
102             + services
103             + "; destroys: "
104             + destroys,
105         inits == 2 && services == 1 && destroys == 2);
106   }
107 
testDispatchRequestToManagedPipelineWithFilter()108   public final void testDispatchRequestToManagedPipelineWithFilter()
109       throws ServletException, IOException {
110     final Injector injector =
111         Guice.createInjector(
112             new ServletModule() {
113 
114               @Override
115               protected void configureServlets() {
116                 filter("/*").through(TestFilter.class);
117 
118                 serve("/*").with(TestServlet.class);
119 
120                 // These servets should never fire...
121                 serve("*.html").with(NeverServlet.class);
122                 serve("/test/*").with(Key.get(NeverServlet.class));
123                 serve("/index/*").with(Key.get(NeverServlet.class));
124                 serve("*.jsp").with(Key.get(NeverServlet.class));
125               }
126             });
127 
128     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
129 
130     pipeline.initPipeline(null);
131 
132     //create ourselves a mock request with test URI
133     HttpServletRequest requestMock = createMock(HttpServletRequest.class);
134 
135     expect(requestMock.getRequestURI()).andReturn("/index.html").times(2);
136     expect(requestMock.getContextPath()).andReturn("").anyTimes();
137 
138     //dispatch request
139     replay(requestMock);
140 
141     pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
142 
143     pipeline.destroyPipeline();
144 
145     verify(requestMock);
146 
147     assertTrue(
148         "lifecycle states did not fire correct number of times-- inits: "
149             + inits
150             + "; dos: "
151             + services
152             + "; destroys: "
153             + destroys
154             + "; doFilters: "
155             + doFilters,
156         inits == 3 && services == 1 && destroys == 3 && doFilters == 1);
157   }
158 
159   @Singleton
160   public static class TestServlet extends HttpServlet {
161     @Override
init(ServletConfig filterConfig)162     public void init(ServletConfig filterConfig) throws ServletException {
163       inits++;
164     }
165 
166     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)167     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
168         throws IOException, ServletException {
169       services++;
170     }
171 
172     @Override
destroy()173     public void destroy() {
174       destroys++;
175     }
176   }
177 
178   @Singleton
179   public static class NeverServlet extends HttpServlet {
180     @Override
init(ServletConfig filterConfig)181     public void init(ServletConfig filterConfig) throws ServletException {
182       inits++;
183     }
184 
185     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)186     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
187         throws IOException, ServletException {
188       fail("NeverServlet was fired, when it should not have been.");
189     }
190 
191     @Override
destroy()192     public void destroy() {
193       destroys++;
194     }
195   }
196 
197   @Singleton
198   public static class TestFilter implements Filter {
199     @Override
init(FilterConfig filterConfig)200     public void init(FilterConfig filterConfig) throws ServletException {
201       inits++;
202     }
203 
204     @Override
doFilter( ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)205     public void doFilter(
206         ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
207         throws IOException, ServletException {
208       doFilters++;
209       filterChain.doFilter(servletRequest, servletResponse);
210     }
211 
212     @Override
destroy()213     public void destroy() {
214       destroys++;
215     }
216   }
217 
218   @Singleton
219   public static class ForwardingServlet extends HttpServlet {
220     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)221     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
222         throws IOException, ServletException {
223       final HttpServletRequest request = (HttpServletRequest) servletRequest;
224 
225       request.getRequestDispatcher("/blah.jsp").forward(servletRequest, servletResponse);
226     }
227   }
228 
229   @Singleton
230   public static class ForwardedServlet extends HttpServlet {
231     static int forwardedTo = 0;
232 
233     // Reset for test.
ForwardedServlet()234     public ForwardedServlet() {
235       forwardedTo = 0;
236     }
237 
238     @Override
service(ServletRequest servletRequest, ServletResponse servletResponse)239     public void service(ServletRequest servletRequest, ServletResponse servletResponse)
240         throws IOException, ServletException {
241       final HttpServletRequest request = (HttpServletRequest) servletRequest;
242 
243       assertTrue((Boolean) request.getAttribute(REQUEST_DISPATCHER_REQUEST));
244       forwardedTo++;
245     }
246   }
247 
testForwardUsingRequestDispatcher()248   public void testForwardUsingRequestDispatcher() throws IOException, ServletException {
249     Guice.createInjector(
250         new ServletModule() {
251           @Override
252           protected void configureServlets() {
253             serve("/").with(ForwardingServlet.class);
254             serve("/blah.jsp").with(ForwardedServlet.class);
255           }
256         });
257 
258     final HttpServletRequest requestMock = createMock(HttpServletRequest.class);
259     HttpServletResponse responseMock = createMock(HttpServletResponse.class);
260     expect(requestMock.getRequestURI()).andReturn("/").anyTimes();
261     expect(requestMock.getContextPath()).andReturn("").anyTimes();
262 
263     requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
264     expect(requestMock.getAttribute(REQUEST_DISPATCHER_REQUEST)).andReturn(true);
265     requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
266 
267     expect(responseMock.isCommitted()).andReturn(false);
268     responseMock.resetBuffer();
269 
270     replay(requestMock, responseMock);
271 
272     new GuiceFilter().doFilter(requestMock, responseMock, createMock(FilterChain.class));
273 
274     assertEquals("Incorrect number of forwards", 1, ForwardedServlet.forwardedTo);
275     verify(requestMock, responseMock);
276   }
277 
testQueryInRequestUri_regex()278   public final void testQueryInRequestUri_regex() throws Exception {
279     final Injector injector =
280         Guice.createInjector(
281             new ServletModule() {
282 
283               @Override
284               protected void configureServlets() {
285                 filterRegex("(.)*\\.html").through(TestFilter.class);
286 
287                 serveRegex("(.)*\\.html").with(TestServlet.class);
288               }
289             });
290 
291     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
292 
293     pipeline.initPipeline(null);
294 
295     //create ourselves a mock request with test URI
296     HttpServletRequest requestMock = createMock(HttpServletRequest.class);
297 
298     expect(requestMock.getRequestURI()).andReturn("/index.html?query=params").atLeastOnce();
299     expect(requestMock.getContextPath()).andReturn("").anyTimes();
300 
301     //dispatch request
302     replay(requestMock);
303 
304     pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
305 
306     pipeline.destroyPipeline();
307 
308     verify(requestMock);
309 
310     assertEquals(1, doFilters);
311     assertEquals(1, services);
312   }
313 
testQueryInRequestUri()314   public final void testQueryInRequestUri() throws Exception {
315     final Injector injector =
316         Guice.createInjector(
317             new ServletModule() {
318 
319               @Override
320               protected void configureServlets() {
321                 filter("/index.html").through(TestFilter.class);
322 
323                 serve("/index.html").with(TestServlet.class);
324               }
325             });
326 
327     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
328 
329     pipeline.initPipeline(null);
330 
331     //create ourselves a mock request with test URI
332     HttpServletRequest requestMock = createMock(HttpServletRequest.class);
333 
334     expect(requestMock.getRequestURI()).andReturn("/index.html?query=params").atLeastOnce();
335     expect(requestMock.getContextPath()).andReturn("").anyTimes();
336 
337     //dispatch request
338     replay(requestMock);
339 
340     pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
341 
342     pipeline.destroyPipeline();
343 
344     verify(requestMock);
345 
346     assertEquals(1, doFilters);
347     assertEquals(1, services);
348   }
349 }
350