1 /* 2 * Copyright (C) 2008 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.google.inject.servlet; 17 18 import com.google.common.base.Preconditions; 19 import com.google.common.collect.Lists; 20 import com.google.common.collect.Sets; 21 import com.google.inject.Binding; 22 import com.google.inject.Inject; 23 import com.google.inject.Injector; 24 import com.google.inject.Singleton; 25 import com.google.inject.TypeLiteral; 26 import java.io.IOException; 27 import java.util.List; 28 import java.util.Set; 29 import javax.servlet.RequestDispatcher; 30 import javax.servlet.ServletContext; 31 import javax.servlet.ServletException; 32 import javax.servlet.ServletRequest; 33 import javax.servlet.ServletResponse; 34 import javax.servlet.http.HttpServlet; 35 import javax.servlet.http.HttpServletRequest; 36 import javax.servlet.http.HttpServletRequestWrapper; 37 38 /** 39 * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for 40 * filters. 41 * 42 * @author dhanji@gmail.com (Dhanji R. Prasanna) 43 */ 44 @Singleton 45 class ManagedServletPipeline { 46 private final ServletDefinition[] servletDefinitions; 47 private static final TypeLiteral<ServletDefinition> SERVLET_DEFS = 48 TypeLiteral.get(ServletDefinition.class); 49 50 @Inject ManagedServletPipeline(Injector injector)51 public ManagedServletPipeline(Injector injector) { 52 this.servletDefinitions = collectServletDefinitions(injector); 53 } 54 hasServletsMapped()55 boolean hasServletsMapped() { 56 return servletDefinitions.length > 0; 57 } 58 59 /** 60 * Introspects the injector and collects all instances of bound {@code List<ServletDefinition>} 61 * into a master list. 62 * 63 * <p>We have a guarantee that {@link com.google.inject.Injector#getBindings()} returns a map that 64 * preserves insertion order in entry-set iterators. 65 */ collectServletDefinitions(Injector injector)66 private ServletDefinition[] collectServletDefinitions(Injector injector) { 67 List<ServletDefinition> servletDefinitions = Lists.newArrayList(); 68 for (Binding<ServletDefinition> entry : injector.findBindingsByType(SERVLET_DEFS)) { 69 servletDefinitions.add(entry.getProvider().get()); 70 } 71 72 // Copy to a fixed size array for speed. 73 return servletDefinitions.toArray(new ServletDefinition[servletDefinitions.size()]); 74 } 75 init(ServletContext servletContext, Injector injector)76 public void init(ServletContext servletContext, Injector injector) throws ServletException { 77 Set<HttpServlet> initializedSoFar = Sets.newIdentityHashSet(); 78 79 for (ServletDefinition servletDefinition : servletDefinitions) { 80 servletDefinition.init(servletContext, injector, initializedSoFar); 81 } 82 } 83 service(ServletRequest request, ServletResponse response)84 public boolean service(ServletRequest request, ServletResponse response) 85 throws IOException, ServletException { 86 87 //stop at the first matching servlet and service 88 for (ServletDefinition servletDefinition : servletDefinitions) { 89 if (servletDefinition.service(request, response)) { 90 return true; 91 } 92 } 93 94 //there was no match... 95 return false; 96 } 97 destroy()98 public void destroy() { 99 Set<HttpServlet> destroyedSoFar = Sets.newIdentityHashSet(); 100 for (ServletDefinition servletDefinition : servletDefinitions) { 101 servletDefinition.destroy(destroyedSoFar); 102 } 103 } 104 105 /** 106 * @return Returns a request dispatcher wrapped with a servlet mapped to the given path or null if 107 * no mapping was found. 108 */ getRequestDispatcher(String path)109 RequestDispatcher getRequestDispatcher(String path) { 110 final String newRequestUri = path; 111 112 // TODO(dhanji): check servlet spec to see if the following is legal or not. 113 // Need to strip query string if requested... 114 115 for (final ServletDefinition servletDefinition : servletDefinitions) { 116 if (servletDefinition.shouldServe(path)) { 117 return new RequestDispatcher() { 118 @Override 119 public void forward(ServletRequest servletRequest, ServletResponse servletResponse) 120 throws ServletException, IOException { 121 Preconditions.checkState( 122 !servletResponse.isCommitted(), 123 "Response has been committed--you can only call forward before" 124 + " committing the response (hint: don't flush buffers)"); 125 126 // clear buffer before forwarding 127 servletResponse.resetBuffer(); 128 129 ServletRequest requestToProcess; 130 if (servletRequest instanceof HttpServletRequest) { 131 requestToProcess = wrapRequest((HttpServletRequest) servletRequest, newRequestUri); 132 } else { 133 // This should never happen, but instead of throwing an exception 134 // we will allow a happy case pass thru for maximum tolerance to 135 // legacy (and internal) code. 136 requestToProcess = servletRequest; 137 } 138 139 // now dispatch to the servlet 140 doServiceImpl(servletDefinition, requestToProcess, servletResponse); 141 } 142 143 @Override 144 public void include(ServletRequest servletRequest, ServletResponse servletResponse) 145 throws ServletException, IOException { 146 // route to the target servlet 147 doServiceImpl(servletDefinition, servletRequest, servletResponse); 148 } 149 150 private void doServiceImpl( 151 ServletDefinition servletDefinition, 152 ServletRequest servletRequest, 153 ServletResponse servletResponse) 154 throws ServletException, IOException { 155 servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE); 156 157 try { 158 servletDefinition.doService(servletRequest, servletResponse); 159 } finally { 160 servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST); 161 } 162 } 163 }; 164 } 165 } 166 167 //otherwise, can't process 168 return null; 169 } 170 171 // visible for testing 172 static HttpServletRequest wrapRequest(HttpServletRequest request, String newUri) { 173 return new RequestDispatcherRequestWrapper(request, newUri); 174 } 175 176 /** 177 * A Marker constant attribute that when present in the request indicates to Guice servlet that 178 * this request has been generated by a request dispatcher rather than the servlet pipeline. In 179 * accordance with section 8.4.2 of the Servlet 2.4 specification. 180 */ 181 public static final String REQUEST_DISPATCHER_REQUEST = "javax.servlet.forward.servlet_path"; 182 183 private static class RequestDispatcherRequestWrapper extends HttpServletRequestWrapper { 184 private final String newRequestUri; 185 186 public RequestDispatcherRequestWrapper( 187 HttpServletRequest servletRequest, String newRequestUri) { 188 super(servletRequest); 189 this.newRequestUri = newRequestUri; 190 } 191 192 @Override 193 public String getRequestURI() { 194 return newRequestUri; 195 } 196 197 @Override 198 public StringBuffer getRequestURL() { 199 StringBuffer url = new StringBuffer(); 200 String scheme = getScheme(); 201 int port = getServerPort(); 202 203 url.append(scheme); 204 url.append("://"); 205 url.append(getServerName()); 206 // port might be -1 in some cases (see java.net.URL.getPort) 207 if (port > 0 208 && (("http".equals(scheme) && (port != 80)) 209 || ("https".equals(scheme) && (port != 443)))) { 210 url.append(':'); 211 url.append(port); 212 } 213 url.append(getRequestURI()); 214 215 return (url); 216 } 217 } 218 } 219