1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import com.google.common.base.Preconditions; 20 import java.nio.Buffer; 21 import java.nio.ByteBuffer; 22 import java.nio.ByteOrder; 23 import java.security.GeneralSecurityException; 24 25 /** Framing and deframing methods and classes used by handshaker. */ 26 public final class AltsFraming { 27 // The size of the frame field. Must correspond to the size of int, 4 bytes. 28 // Left package-private for testing. 29 private static final int FRAME_LENGTH_HEADER_SIZE = 4; 30 private static final int FRAME_MESSAGE_TYPE_HEADER_SIZE = 4; 31 private static final int MAX_DATA_LENGTH = 1024 * 1024; 32 private static final int INITIAL_BUFFER_CAPACITY = 1024 * 64; 33 34 // TODO: Make this the responsibility of the caller. 35 private static final int MESSAGE_TYPE = 6; 36 AltsFraming()37 private AltsFraming() {} 38 getFrameLengthHeaderSize()39 static int getFrameLengthHeaderSize() { 40 return FRAME_LENGTH_HEADER_SIZE; 41 } 42 getFrameMessageTypeHeaderSize()43 static int getFrameMessageTypeHeaderSize() { 44 return FRAME_MESSAGE_TYPE_HEADER_SIZE; 45 } 46 getMaxDataLength()47 static int getMaxDataLength() { 48 return MAX_DATA_LENGTH; 49 } 50 getFramingOverhead()51 static int getFramingOverhead() { 52 return FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE; 53 } 54 55 /** 56 * Creates a frame of length dataSize + FRAME_HEADER_SIZE using the input bytes, if dataSize <= 57 * input.remaining(). Otherwise, a frame of length input.remaining() + FRAME_HEADER_SIZE is 58 * created. 59 */ toFrame(ByteBuffer input, int dataSize)60 static ByteBuffer toFrame(ByteBuffer input, int dataSize) throws GeneralSecurityException { 61 Preconditions.checkNotNull(input); 62 if (dataSize > input.remaining()) { 63 dataSize = input.remaining(); 64 } 65 Producer producer = new Producer(); 66 ByteBuffer inputAlias = input.duplicate(); 67 ((Buffer) inputAlias).limit(input.position() + dataSize); 68 producer.readBytes(inputAlias); 69 producer.flush(); 70 ((Buffer) input).position(inputAlias.position()); 71 ByteBuffer output = producer.getRawFrame(); 72 return output; 73 } 74 75 /** 76 * A helper class to write a frame. 77 * 78 * <p>This class guarantees that one of the following is true: 79 * 80 * <ul> 81 * <li>readBytes will read from the input 82 * <li>writeBytes will write to the output 83 * </ul> 84 * 85 * <p>Sample usage: 86 * 87 * <pre>{@code 88 * Producer producer = new Producer(); 89 * ByteBuffer inputBuffer = readBytesFromMyStream(); 90 * ByteBuffer outputBuffer = writeBytesToMyStream(); 91 * while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) { 92 * producer.readBytes(inputBuffer); 93 * producer.writeBytes(outputBuffer); 94 * } 95 * }</pre> 96 * 97 * <p>Alternatively, this class guarantees that one of the following is true: 98 * 99 * <ul> 100 * <li>readBytes will read from the input 101 * <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a 102 * processed frame. 103 * </ul> 104 * 105 * <p>Sample usage: 106 * 107 * <pre>{@code 108 * Producer producer = new Producer(); 109 * while (!producer.isComplete()) { 110 * ByteBuffer inputBuffer = readBytesFromMyStream(); 111 * producer.readBytes(inputBuffer); 112 * } 113 * producer.flush(); 114 * ByteBuffer outputBuffer = producer.getRawFrame(); 115 * }</pre> 116 */ 117 static final class Producer { 118 private ByteBuffer buffer; 119 private boolean isComplete; 120 Producer(int maxFrameSize)121 Producer(int maxFrameSize) { 122 buffer = ByteBuffer.allocate(maxFrameSize); 123 reset(); 124 Preconditions.checkArgument(maxFrameSize > getFramePrefixLength() + getFrameSuffixLength()); 125 } 126 Producer()127 Producer() { 128 this(INITIAL_BUFFER_CAPACITY); 129 } 130 131 /** The length of the frame prefix data, including the message length/type fields. */ getFramePrefixLength()132 int getFramePrefixLength() { 133 int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE; 134 return result; 135 } 136 getFrameSuffixLength()137 int getFrameSuffixLength() { 138 return 0; 139 } 140 141 /** 142 * Reads bytes from input, parsing them into a frame. Returns false if and only if more data is 143 * needed. To obtain a full frame this method must be called repeatedly until it returns true. 144 */ readBytes(ByteBuffer input)145 boolean readBytes(ByteBuffer input) throws GeneralSecurityException { 146 Preconditions.checkNotNull(input); 147 if (isComplete) { 148 return true; 149 } 150 copy(buffer, input); 151 if (!buffer.hasRemaining()) { 152 flush(); 153 } 154 return isComplete; 155 } 156 157 /** 158 * Completes the current frame, signaling that no further data is available to be passed to 159 * readBytes and that the client requires writeBytes to start returning data. isComplete() is 160 * guaranteed to return true after this call. 161 */ flush()162 void flush() throws GeneralSecurityException { 163 if (isComplete) { 164 return; 165 } 166 // Get the length of the complete frame. 167 int frameLength = buffer.position() + getFrameSuffixLength(); 168 169 // Set the limit and move to the start. 170 ((Buffer) buffer).flip(); 171 172 // Advance the limit to allow a crypto suffix. 173 ((Buffer) buffer).limit(buffer.limit() + getFrameSuffixLength()); 174 175 // Write the data length and the message type. 176 int dataLength = frameLength - FRAME_LENGTH_HEADER_SIZE; 177 buffer.order(ByteOrder.LITTLE_ENDIAN); 178 buffer.putInt(dataLength); 179 buffer.putInt(MESSAGE_TYPE); 180 181 // Move the position back to 0, the frame is ready. 182 ((Buffer) buffer).position(0); 183 isComplete = true; 184 } 185 186 /** Resets the state, preparing to construct a new frame. Must be called between frames. */ reset()187 private void reset() { 188 ((Buffer) buffer).clear(); 189 190 // Save some space for framing, we'll fill that in later. 191 ((Buffer) buffer).position(getFramePrefixLength()); 192 ((Buffer) buffer).limit(buffer.limit() - getFrameSuffixLength()); 193 194 isComplete = false; 195 } 196 197 /** 198 * Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be 199 * called when isComplete() returns true, otherwise null is returned. The returned object 200 * aliases the internal buffer, that is, it shares memory with the internal buffer. No further 201 * operations are permitted on this object until the caller has processed the data it needs from 202 * the returned byte buffer. 203 */ getRawFrame()204 ByteBuffer getRawFrame() { 205 if (!isComplete) { 206 return null; 207 } 208 ByteBuffer result = buffer.duplicate(); 209 reset(); 210 return result; 211 } 212 } 213 214 /** 215 * A helper class to read a frame. 216 * 217 * <p>This class guarantees that one of the following is true: 218 * 219 * <ul> 220 * <li>readBytes will read from the input 221 * <li>writeBytes will write to the output 222 * </ul> 223 * 224 * <p>Sample usage: 225 * 226 * <pre>{@code 227 * Parser parser = new Parser(); 228 * ByteBuffer inputBuffer = readBytesFromMyStream(); 229 * ByteBuffer outputBuffer = writeBytesToMyStream(); 230 * while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) { 231 * parser.readBytes(inputBuffer); 232 * parser.writeBytes(outputBuffer); } 233 * }</pre> 234 * 235 * <p>Alternatively, this class guarantees that one of the following is true: 236 * 237 * <ul> 238 * <li>readBytes will read from the input 239 * <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a 240 * processed frame. 241 * </ul> 242 * 243 * <p>Sample usage: 244 * 245 * <pre>{@code 246 * Parser parser = new Parser(); 247 * while (!parser.isComplete()) { 248 * ByteBuffer inputBuffer = readBytesFromMyStream(); 249 * parser.readBytes(inputBuffer); 250 * } 251 * ByteBuffer outputBuffer = parser.getRawFrame(); 252 * }</pre> 253 */ 254 public static final class Parser { 255 private ByteBuffer buffer = ByteBuffer.allocate(INITIAL_BUFFER_CAPACITY); 256 private boolean isComplete = false; 257 Parser()258 public Parser() { 259 Preconditions.checkArgument( 260 INITIAL_BUFFER_CAPACITY > getFramePrefixLength() + getFrameSuffixLength()); 261 } 262 263 /** 264 * Reads bytes from input, parsing them into a frame. Returns false if and only if more data is 265 * needed. To obtain a full frame this method must be called repeatedly until it returns true. 266 */ readBytes(ByteBuffer input)267 public boolean readBytes(ByteBuffer input) throws GeneralSecurityException { 268 Preconditions.checkNotNull(input); 269 270 if (isComplete) { 271 return true; 272 } 273 274 // Read enough bytes to determine the length 275 while (buffer.position() < FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) { 276 buffer.put(input.get()); 277 } 278 279 // If we have enough bytes to determine the length, read the length and ensure that our 280 // internal buffer is large enough. 281 if (buffer.position() == FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) { 282 ByteBuffer bufferAlias = buffer.duplicate(); 283 ((Buffer) bufferAlias).flip(); 284 bufferAlias.order(ByteOrder.LITTLE_ENDIAN); 285 int dataLength = bufferAlias.getInt(); 286 if (dataLength < FRAME_MESSAGE_TYPE_HEADER_SIZE || dataLength > MAX_DATA_LENGTH) { 287 throw new IllegalArgumentException("Invalid frame length " + dataLength); 288 } 289 // Maybe resize the buffer 290 int frameLength = dataLength + FRAME_LENGTH_HEADER_SIZE; 291 if (buffer.capacity() < frameLength) { 292 buffer = ByteBuffer.allocate(frameLength); 293 buffer.order(ByteOrder.LITTLE_ENDIAN); 294 buffer.putInt(dataLength); 295 } 296 ((Buffer) buffer).limit(frameLength); 297 } 298 299 // TODO: Similarly extract and check message type. 300 301 // Read the remaining data into the internal buffer. 302 copy(buffer, input); 303 if (!buffer.hasRemaining()) { 304 ((Buffer) buffer).flip(); 305 isComplete = true; 306 } 307 return isComplete; 308 } 309 310 /** The length of the frame prefix data, including the message length/type fields. */ getFramePrefixLength()311 int getFramePrefixLength() { 312 int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE; 313 return result; 314 } 315 getFrameSuffixLength()316 int getFrameSuffixLength() { 317 return 0; 318 } 319 320 /** Returns true if we've parsed a complete frame. */ isComplete()321 public boolean isComplete() { 322 return isComplete; 323 } 324 325 /** Resets the state, preparing to parse a new frame. Must be called between frames. */ reset()326 private void reset() { 327 ((Buffer) buffer).clear(); 328 isComplete = false; 329 } 330 331 /** 332 * Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be 333 * called when isComplete() returns true, otherwise null is returned. The returned object 334 * aliases the internal buffer, that is, it shares memory with the internal buffer. No further 335 * operations are permitted on this object until the caller has processed the data it needs from 336 * the returned byte buffer. 337 */ getRawFrame()338 public ByteBuffer getRawFrame() { 339 if (!isComplete) { 340 return null; 341 } 342 ByteBuffer result = buffer.duplicate(); 343 reset(); 344 return result; 345 } 346 } 347 348 /** 349 * Copy as much as possible to dst from src. Unlike {@link ByteBuffer#put(ByteBuffer)}, this stops 350 * early if there is no room left in dst. 351 */ copy(ByteBuffer dst, ByteBuffer src)352 private static void copy(ByteBuffer dst, ByteBuffer src) { 353 if (dst.hasRemaining() && src.hasRemaining()) { 354 // Avoid an allocation if possible. 355 if (dst.remaining() >= src.remaining()) { 356 dst.put(src); 357 } else { 358 int count = Math.min(dst.remaining(), src.remaining()); 359 ByteBuffer slice = src.slice(); 360 ((Buffer) slice).limit(count); 361 dst.put(slice); 362 ((Buffer) src).position(src.position() + count); 363 } 364 } 365 } 366 } 367