1 /* 2 * Copyright (C) 2012 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.inject.servlet; 18 19 import static com.google.common.base.Charsets.UTF_8; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 22 import com.google.common.base.Joiner; 23 import com.google.common.base.Splitter; 24 import com.google.common.net.UrlEscapers; 25 import java.nio.charset.Charset; 26 import java.util.ArrayList; 27 import java.util.Arrays; 28 import java.util.List; 29 import javax.servlet.http.HttpServletRequest; 30 31 /** 32 * Some servlet utility methods. 33 * 34 * @author ntang@google.com (Michael Tang) 35 */ 36 final class ServletUtils { 37 private static final Splitter SLASH_SPLITTER = Splitter.on('/'); 38 private static final Joiner SLASH_JOINER = Joiner.on('/'); 39 ServletUtils()40 private ServletUtils() { 41 // private to prevent instantiation. 42 } 43 44 /** 45 * Gets the context path relative path of the URI. Returns the path of the resource relative to 46 * the context path for a request's URI, or null if no path can be extracted. 47 * 48 * <p>Also performs url decoding and normalization of the path. 49 */ 50 // @Nullable getContextRelativePath( final HttpServletRequest request)51 static String getContextRelativePath( 52 // @Nullable 53 final HttpServletRequest request) { 54 if (request != null) { 55 String contextPath = request.getContextPath(); 56 String requestURI = request.getRequestURI(); 57 if (contextPath.length() < requestURI.length()) { 58 String suffix = requestURI.substring(contextPath.length()); 59 return normalizePath(suffix); 60 } else if (requestURI.trim().length() > 0 && contextPath.length() == requestURI.length()) { 61 return "/"; 62 } 63 } 64 return null; 65 } 66 67 /** Normalizes a path by unescaping all safe, percent encoded characters. */ normalizePath(String path)68 static String normalizePath(String path) { 69 StringBuilder sb = new StringBuilder(path.length()); 70 int queryStart = path.indexOf('?'); 71 String query = null; 72 if (queryStart != -1) { 73 query = path.substring(queryStart); 74 path = path.substring(0, queryStart); 75 } 76 // Normalize the path. we need to decode path segments, normalize and rejoin in order to 77 // 1. decode and normalize safe percent escaped characters. e.g. %70 -> 'p' 78 // 2. decode and interpret dangerous character sequences. e.g. /%2E/ -> '/./' -> '/' 79 // 3. preserve dangerous encoded characters. e.g. '/%2F/' -> '///' -> '/%2F' 80 List<String> segments = new ArrayList<>(); 81 for (String segment : SLASH_SPLITTER.split(path)) { 82 // This decodes all non-special characters from the path segment. so if someone passes 83 // /%2E/foo we will normalize it to /./foo and then /foo 84 String normalized = 85 UrlEscapers.urlPathSegmentEscaper().escape(lenientDecode(segment, UTF_8, false)); 86 if (".".equals(normalized)) { 87 // skip 88 } else if ("..".equals(normalized)) { 89 if (segments.size() > 1) { 90 segments.remove(segments.size() - 1); 91 } 92 } else { 93 segments.add(normalized); 94 } 95 } 96 SLASH_JOINER.appendTo(sb, segments); 97 if (query != null) { 98 sb.append(query); 99 } 100 return sb.toString(); 101 } 102 103 104 /** 105 * Percent-decodes a US-ASCII string into a Unicode string. The specified encoding is used to 106 * determine what characters are represented by any consecutive sequences of the form 107 * "%<i>XX</i>". This is the lenient kind of decoding that will simply ignore and copy as-is any 108 * "%XX" sequence that is invalid (for example, "%HH"). 109 * 110 * @param string a percent-encoded US-ASCII string 111 * @param encoding a character encoding 112 * @param decodePlus boolean to indicate whether to decode '+' as ' ' 113 * @return a Unicode string 114 */ lenientDecode(String string, Charset encoding, boolean decodePlus)115 private static String lenientDecode(String string, Charset encoding, boolean decodePlus) { 116 117 checkNotNull(string); 118 checkNotNull(encoding); 119 120 if (decodePlus) { 121 string = string.replace('+', ' '); 122 } 123 124 int firstPercentPos = string.indexOf('%'); 125 126 if (firstPercentPos < 0) { 127 return string; 128 } 129 130 ByteAccumulator accumulator = new ByteAccumulator(string.length(), encoding); 131 StringBuilder builder = new StringBuilder(string.length()); 132 133 if (firstPercentPos > 0) { 134 builder.append(string, 0, firstPercentPos); 135 } 136 137 for (int srcPos = firstPercentPos; srcPos < string.length(); srcPos++) { 138 139 char c = string.charAt(srcPos); 140 141 if (c < 0x80) { // ASCII 142 boolean processed = false; 143 144 if (c == '%' && string.length() >= srcPos + 3) { 145 String hex = string.substring(srcPos + 1, srcPos + 3); 146 147 try { 148 int encoded = Integer.parseInt(hex, 16); 149 150 if (encoded >= 0) { 151 accumulator.append((byte) encoded); 152 srcPos += 2; 153 processed = true; 154 } 155 } catch (NumberFormatException ignore) { 156 // Expected case (badly formatted % group) 157 } 158 } 159 160 if (!processed) { 161 if (accumulator.isEmpty()) { 162 // We're not accumulating elements of a multibyte encoded 163 // char, so just toss it right into the result string. 164 165 builder.append(c); 166 } else { 167 accumulator.append((byte) c); 168 } 169 } 170 } else { // Non-ASCII 171 // A non-ASCII char marks the end of a multi-char encoding sequence, 172 // if one is in progress. 173 174 accumulator.dumpTo(builder); 175 builder.append(c); 176 } 177 } 178 179 accumulator.dumpTo(builder); 180 181 return builder.toString(); 182 } 183 184 /** Accumulates byte sequences while decoding strings, and encodes them into a StringBuilder. */ 185 private static class ByteAccumulator { 186 private byte[] bytes; 187 private int length; 188 private final Charset encoding; 189 ByteAccumulator(int capacity, Charset encoding)190 ByteAccumulator(int capacity, Charset encoding) { 191 this.bytes = new byte[Math.min(16, capacity)]; 192 this.encoding = encoding; 193 } 194 append(byte b)195 void append(byte b) { 196 ensureCapacity(length + 1); 197 bytes[length++] = b; 198 } 199 dumpTo(StringBuilder dest)200 void dumpTo(StringBuilder dest) { 201 if (length != 0) { 202 dest.append(new String(bytes, 0, length, encoding)); 203 length = 0; 204 } 205 } 206 isEmpty()207 boolean isEmpty() { 208 return length == 0; 209 } 210 ensureCapacity(int minCapacity)211 private void ensureCapacity(int minCapacity) { 212 int cap = bytes.length; 213 if (cap >= minCapacity) { 214 return; 215 } 216 int newCapacity = cap + (cap >> 1); // *1.5 217 if (newCapacity < minCapacity) { 218 // we are close to overflowing, grow by smaller steps 219 newCapacity = minCapacity; 220 } 221 // in other cases, we will naturally throw an OOM from here 222 bytes = Arrays.copyOf(bytes, newCapacity); 223 } 224 } 225 226 } 227