• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2008 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 package com.google.inject.servlet;
17 
18 import com.google.common.base.Preconditions;
19 import com.google.common.collect.Lists;
20 import com.google.common.collect.Sets;
21 import com.google.inject.Binding;
22 import com.google.inject.Inject;
23 import com.google.inject.Injector;
24 import com.google.inject.Singleton;
25 import com.google.inject.TypeLiteral;
26 import java.io.IOException;
27 import java.util.List;
28 import java.util.Set;
29 import javax.servlet.RequestDispatcher;
30 import javax.servlet.ServletContext;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletRequest;
33 import javax.servlet.ServletResponse;
34 import javax.servlet.http.HttpServlet;
35 import javax.servlet.http.HttpServletRequest;
36 import javax.servlet.http.HttpServletRequestWrapper;
37 
38 /**
39  * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for
40  * filters.
41  *
42  * @author dhanji@gmail.com (Dhanji R. Prasanna)
43  */
44 @Singleton
45 class ManagedServletPipeline {
46   private final ServletDefinition[] servletDefinitions;
47   private static final TypeLiteral<ServletDefinition> SERVLET_DEFS =
48       TypeLiteral.get(ServletDefinition.class);
49 
50   @Inject
ManagedServletPipeline(Injector injector)51   public ManagedServletPipeline(Injector injector) {
52     this.servletDefinitions = collectServletDefinitions(injector);
53   }
54 
hasServletsMapped()55   boolean hasServletsMapped() {
56     return servletDefinitions.length > 0;
57   }
58 
59   /**
60    * Introspects the injector and collects all instances of bound {@code List<ServletDefinition>}
61    * into a master list.
62    *
63    * <p>We have a guarantee that {@link com.google.inject.Injector#getBindings()} returns a map that
64    * preserves insertion order in entry-set iterators.
65    */
collectServletDefinitions(Injector injector)66   private ServletDefinition[] collectServletDefinitions(Injector injector) {
67     List<ServletDefinition> servletDefinitions = Lists.newArrayList();
68     for (Binding<ServletDefinition> entry : injector.findBindingsByType(SERVLET_DEFS)) {
69       servletDefinitions.add(entry.getProvider().get());
70     }
71 
72     // Copy to a fixed size array for speed.
73     return servletDefinitions.toArray(new ServletDefinition[servletDefinitions.size()]);
74   }
75 
init(ServletContext servletContext, Injector injector)76   public void init(ServletContext servletContext, Injector injector) throws ServletException {
77     Set<HttpServlet> initializedSoFar = Sets.newIdentityHashSet();
78 
79     for (ServletDefinition servletDefinition : servletDefinitions) {
80       servletDefinition.init(servletContext, injector, initializedSoFar);
81     }
82   }
83 
service(ServletRequest request, ServletResponse response)84   public boolean service(ServletRequest request, ServletResponse response)
85       throws IOException, ServletException {
86 
87     //stop at the first matching servlet and service
88     for (ServletDefinition servletDefinition : servletDefinitions) {
89       if (servletDefinition.service(request, response)) {
90         return true;
91       }
92     }
93 
94     //there was no match...
95     return false;
96   }
97 
destroy()98   public void destroy() {
99     Set<HttpServlet> destroyedSoFar = Sets.newIdentityHashSet();
100     for (ServletDefinition servletDefinition : servletDefinitions) {
101       servletDefinition.destroy(destroyedSoFar);
102     }
103   }
104 
105   /**
106    * @return Returns a request dispatcher wrapped with a servlet mapped to the given path or null if
107    *     no mapping was found.
108    */
getRequestDispatcher(String path)109   RequestDispatcher getRequestDispatcher(String path) {
110     final String newRequestUri = path;
111 
112     // TODO(dhanji): check servlet spec to see if the following is legal or not.
113     // Need to strip query string if requested...
114 
115     for (final ServletDefinition servletDefinition : servletDefinitions) {
116       if (servletDefinition.shouldServe(path)) {
117         return new RequestDispatcher() {
118           @Override
119           public void forward(ServletRequest servletRequest, ServletResponse servletResponse)
120               throws ServletException, IOException {
121             Preconditions.checkState(
122                 !servletResponse.isCommitted(),
123                 "Response has been committed--you can only call forward before"
124                     + " committing the response (hint: don't flush buffers)");
125 
126             // clear buffer before forwarding
127             servletResponse.resetBuffer();
128 
129             ServletRequest requestToProcess;
130             if (servletRequest instanceof HttpServletRequest) {
131               requestToProcess = wrapRequest((HttpServletRequest) servletRequest, newRequestUri);
132             } else {
133               // This should never happen, but instead of throwing an exception
134               // we will allow a happy case pass thru for maximum tolerance to
135               // legacy (and internal) code.
136               requestToProcess = servletRequest;
137             }
138 
139             // now dispatch to the servlet
140             doServiceImpl(servletDefinition, requestToProcess, servletResponse);
141           }
142 
143           @Override
144           public void include(ServletRequest servletRequest, ServletResponse servletResponse)
145               throws ServletException, IOException {
146             // route to the target servlet
147             doServiceImpl(servletDefinition, servletRequest, servletResponse);
148           }
149 
150           private void doServiceImpl(
151               ServletDefinition servletDefinition,
152               ServletRequest servletRequest,
153               ServletResponse servletResponse)
154               throws ServletException, IOException {
155             servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
156 
157             try {
158               servletDefinition.doService(servletRequest, servletResponse);
159             } finally {
160               servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
161             }
162           }
163         };
164       }
165     }
166 
167     //otherwise, can't process
168     return null;
169   }
170 
171   // visible for testing
172   static HttpServletRequest wrapRequest(HttpServletRequest request, String newUri) {
173     return new RequestDispatcherRequestWrapper(request, newUri);
174   }
175 
176   /**
177    * A Marker constant attribute that when present in the request indicates to Guice servlet that
178    * this request has been generated by a request dispatcher rather than the servlet pipeline. In
179    * accordance with section 8.4.2 of the Servlet 2.4 specification.
180    */
181   public static final String REQUEST_DISPATCHER_REQUEST = "javax.servlet.forward.servlet_path";
182 
183   private static class RequestDispatcherRequestWrapper extends HttpServletRequestWrapper {
184     private final String newRequestUri;
185 
186     public RequestDispatcherRequestWrapper(
187         HttpServletRequest servletRequest, String newRequestUri) {
188       super(servletRequest);
189       this.newRequestUri = newRequestUri;
190     }
191 
192     @Override
193     public String getRequestURI() {
194       return newRequestUri;
195     }
196 
197     @Override
198     public StringBuffer getRequestURL() {
199       StringBuffer url = new StringBuffer();
200       String scheme = getScheme();
201       int port = getServerPort();
202 
203       url.append(scheme);
204       url.append("://");
205       url.append(getServerName());
206       // port might be -1 in some cases (see java.net.URL.getPort)
207       if (port > 0
208           && (("http".equals(scheme) && (port != 80))
209               || ("https".equals(scheme) && (port != 443)))) {
210         url.append(':');
211         url.append(port);
212       }
213       url.append(getRequestURI());
214 
215       return (url);
216     }
217   }
218 }
219