• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (C) 2008 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.inject.servlet;
17 
18 import com.google.common.base.Preconditions;
19 import com.google.common.collect.Lists;
20 import com.google.common.collect.Sets;
21 import com.google.inject.Binding;
22 import com.google.inject.Inject;
23 import com.google.inject.Injector;
24 import com.google.inject.Singleton;
25 import com.google.inject.TypeLiteral;
26 
27 import java.io.IOException;
28 import java.util.List;
29 import java.util.Set;
30 
31 import javax.servlet.RequestDispatcher;
32 import javax.servlet.ServletContext;
33 import javax.servlet.ServletException;
34 import javax.servlet.ServletRequest;
35 import javax.servlet.ServletResponse;
36 import javax.servlet.http.HttpServlet;
37 import javax.servlet.http.HttpServletRequest;
38 import javax.servlet.http.HttpServletRequestWrapper;
39 
40 /**
41  * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for
42  * filters.
43  *
44  * @author dhanji@gmail.com (Dhanji R. Prasanna)
45  */
46 @Singleton
47 class ManagedServletPipeline {
48   private final ServletDefinition[] servletDefinitions;
49   private static final TypeLiteral<ServletDefinition> SERVLET_DEFS =
50       TypeLiteral.get(ServletDefinition.class);
51 
52   @Inject
ManagedServletPipeline(Injector injector)53   public ManagedServletPipeline(Injector injector) {
54     this.servletDefinitions = collectServletDefinitions(injector);
55   }
56 
hasServletsMapped()57   boolean hasServletsMapped() {
58     return servletDefinitions.length > 0;
59   }
60 
61   /**
62    * Introspects the injector and collects all instances of bound {@code List<ServletDefinition>}
63    * into a master list.
64    *
65    * We have a guarantee that {@link com.google.inject.Injector#getBindings()} returns a map
66    * that preserves insertion order in entry-set iterators.
67    */
collectServletDefinitions(Injector injector)68   private ServletDefinition[] collectServletDefinitions(Injector injector) {
69     List<ServletDefinition> servletDefinitions = Lists.newArrayList();
70     for (Binding<ServletDefinition> entry : injector.findBindingsByType(SERVLET_DEFS)) {
71         servletDefinitions.add(entry.getProvider().get());
72     }
73 
74     // Copy to a fixed size array for speed.
75     return servletDefinitions.toArray(new ServletDefinition[servletDefinitions.size()]);
76   }
77 
init(ServletContext servletContext, Injector injector)78   public void init(ServletContext servletContext, Injector injector) throws ServletException {
79     Set<HttpServlet> initializedSoFar = Sets.newIdentityHashSet();
80 
81     for (ServletDefinition servletDefinition : servletDefinitions) {
82       servletDefinition.init(servletContext, injector, initializedSoFar);
83     }
84   }
85 
service(ServletRequest request, ServletResponse response)86   public boolean service(ServletRequest request, ServletResponse response)
87       throws IOException, ServletException {
88 
89     //stop at the first matching servlet and service
90     for (ServletDefinition servletDefinition : servletDefinitions) {
91       if (servletDefinition.service(request, response)) {
92         return true;
93       }
94     }
95 
96     //there was no match...
97     return false;
98   }
99 
destroy()100   public void destroy() {
101     Set<HttpServlet> destroyedSoFar = Sets.newIdentityHashSet();
102     for (ServletDefinition servletDefinition : servletDefinitions) {
103       servletDefinition.destroy(destroyedSoFar);
104     }
105   }
106 
107   /**
108    * @return Returns a request dispatcher wrapped with a servlet mapped to
109    * the given path or null if no mapping was found.
110    */
getRequestDispatcher(String path)111   RequestDispatcher getRequestDispatcher(String path) {
112     final String newRequestUri = path;
113 
114     // TODO(dhanji): check servlet spec to see if the following is legal or not.
115     // Need to strip query string if requested...
116 
117     for (final ServletDefinition servletDefinition : servletDefinitions) {
118       if (servletDefinition.shouldServe(path)) {
119         return new RequestDispatcher() {
120           public void forward(ServletRequest servletRequest, ServletResponse servletResponse)
121               throws ServletException, IOException {
122             Preconditions.checkState(!servletResponse.isCommitted(),
123                 "Response has been committed--you can only call forward before"
124                 + " committing the response (hint: don't flush buffers)");
125 
126             // clear buffer before forwarding
127             servletResponse.resetBuffer();
128 
129             ServletRequest requestToProcess;
130             if (servletRequest instanceof HttpServletRequest) {
131                requestToProcess = wrapRequest((HttpServletRequest)servletRequest, newRequestUri);
132             } else {
133               // This should never happen, but instead of throwing an exception
134               // we will allow a happy case pass thru for maximum tolerance to
135               // legacy (and internal) code.
136               requestToProcess = servletRequest;
137             }
138 
139             // now dispatch to the servlet
140             doServiceImpl(servletDefinition, requestToProcess, servletResponse);
141           }
142 
143           public void include(ServletRequest servletRequest, ServletResponse servletResponse)
144               throws ServletException, IOException {
145             // route to the target servlet
146             doServiceImpl(servletDefinition, servletRequest, servletResponse);
147           }
148 
149           private void doServiceImpl(ServletDefinition servletDefinition, ServletRequest servletRequest,
150               ServletResponse servletResponse) throws ServletException, IOException {
151             servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
152 
153             try {
154               servletDefinition.doService(servletRequest, servletResponse);
155             } finally {
156               servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
157             }
158           }
159         };
160       }
161     }
162 
163     //otherwise, can't process
164     return null;
165   }
166 
167   // visible for testing
168   static HttpServletRequest wrapRequest(HttpServletRequest request, String newUri) {
169     return new RequestDispatcherRequestWrapper(request, newUri);
170   }
171 
172   /**
173    * A Marker constant attribute that when present in the request indicates to Guice servlet that
174    * this request has been generated by a request dispatcher rather than the servlet pipeline.
175    * In accordance with section 8.4.2 of the Servlet 2.4 specification.
176    */
177   public static final String REQUEST_DISPATCHER_REQUEST = "javax.servlet.forward.servlet_path";
178 
179   private static class RequestDispatcherRequestWrapper extends HttpServletRequestWrapper {
180     private final String newRequestUri;
181 
182     public RequestDispatcherRequestWrapper(HttpServletRequest servletRequest, String newRequestUri) {
183       super(servletRequest);
184       this.newRequestUri = newRequestUri;
185     }
186 
187     @Override
188     public String getRequestURI() {
189       return newRequestUri;
190     }
191 
192     @Override
193     public StringBuffer getRequestURL() {
194       StringBuffer url = new StringBuffer();
195       String scheme = getScheme();
196       int port = getServerPort();
197 
198       url.append(scheme);
199       url.append("://");
200       url.append(getServerName());
201       // port might be -1 in some cases (see java.net.URL.getPort)
202       if (port > 0 &&
203           (("http".equals(scheme) && (port != 80)) ||
204            ("https".equals(scheme) && (port != 443)))) {
205         url.append(':');
206         url.append(port);
207       }
208       url.append(getRequestURI());
209 
210       return (url);
211     }
212   }
213 }
214