1 /* 2 * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"). 5 * You may not use this file except in compliance with the License. 6 * A copy of the License is located at 7 * 8 * http://aws.amazon.com/apache2.0 9 * 10 * or in the "license" file accompanying this file. This file is distributed 11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 * express or implied. See the License for the specific language governing 13 * permissions and limitations under the License. 14 */ 15 package software.amazon.eventstream; 16 17 import java.io.ByteArrayOutputStream; 18 import java.io.DataOutputStream; 19 import java.io.IOException; 20 import java.io.OutputStream; 21 import java.nio.ByteBuffer; 22 import java.nio.charset.StandardCharsets; 23 import java.util.Arrays; 24 import java.util.Base64; 25 import java.util.Collections; 26 import java.util.LinkedHashMap; 27 import java.util.Map; 28 import java.util.Map.Entry; 29 import java.util.zip.CRC32; 30 import java.util.zip.CheckedOutputStream; 31 import java.util.zip.Checksum; 32 33 import static java.lang.String.format; 34 import static java.util.Objects.requireNonNull; 35 36 /** 37 * An eventstream message. 38 */ 39 public class Message { 40 private static final int TRAILING_CRC_LENGTH = 4; 41 static final int MESSAGE_OVERHEAD = Prelude.LENGTH_WITH_CRC + TRAILING_CRC_LENGTH; 42 43 private final Map<String, HeaderValue> headers; 44 private final byte[] payload; 45 46 /** 47 * Create a message. 48 * 49 * @param headers a non-null key-value map of headers, which will be encoded on the wire according to the 50 * iteration order of the given map 51 * @param payload a non-null byte array containing zero or more bytes 52 */ Message(Map<String, HeaderValue> headers, byte[] payload)53 public Message(Map<String, HeaderValue> headers, byte[] payload) { 54 this.headers = requireNonNull(headers, "headers"); 55 this.payload = requireNonNull(payload, "payload").clone(); 56 } 57 58 /** 59 * Returns this message's headers. If this message was obtained through {@link #decode(ByteBuffer)}, the iteration 60 * order of the headers is guaranteed to match the order in which the headers appeared on the wire. 61 * 62 * @return an unmodifiable non-null key-value map of this message's headers 63 */ getHeaders()64 public Map<String, HeaderValue> getHeaders() { 65 return headers; 66 } 67 68 /** 69 * Returns this message's payload. 70 * 71 * @return a non-null array of zero or more bytes 72 */ getPayload()73 public byte[] getPayload() { 74 return payload.clone(); 75 } 76 decode(ByteBuffer buf)77 public static Message decode(ByteBuffer buf) { 78 return decode(Prelude.decode(buf), buf); 79 } 80 81 /** 82 * Decodes a message with an already decoded prelude. Useful for not decoding the prelude twice. 83 * 84 * @param prelude Decoded prelude of message. 85 * @param buf Data of message (including prelude which will be skipped over). 86 * @return Decoded message 87 */ decode(Prelude prelude, ByteBuffer buf)88 static Message decode(Prelude prelude, ByteBuffer buf) { 89 int totalLength = prelude.getTotalLength(); 90 validateMessageCrc(buf, totalLength); 91 buf.position(buf.position() + Prelude.LENGTH_WITH_CRC); 92 93 long headersLength = prelude.getHeadersLength(); 94 byte[] headerBytes = new byte[Math.toIntExact(headersLength)]; 95 buf.get(headerBytes); 96 Map<String, HeaderValue> headers = decodeHeaders(ByteBuffer.wrap(headerBytes)); 97 98 byte[] payload = new byte[Math.toIntExact(totalLength - MESSAGE_OVERHEAD - headersLength)]; 99 buf.get(payload); 100 buf.getInt(); // skip past the message CRC 101 102 return new Message(headers, payload); 103 } 104 validateMessageCrc(ByteBuffer buf, int totalLength)105 private static void validateMessageCrc(ByteBuffer buf, int totalLength) { 106 Checksum crc = new CRC32(); 107 108 Checksums.update(crc, (ByteBuffer) buf.duplicate().limit(buf.position() + totalLength - 4)); 109 long computedMessageCrc = crc.getValue(); 110 111 long wireMessageCrc = Integer.toUnsignedLong(buf.getInt(buf.position() + totalLength - 4)); 112 113 if (wireMessageCrc != computedMessageCrc) { 114 throw new IllegalArgumentException(format("Message checksum failure: expected 0x%x, computed 0x%x", 115 wireMessageCrc, computedMessageCrc)); 116 } 117 } 118 decodeHeaders(ByteBuffer buf)119 static Map<String, HeaderValue> decodeHeaders(ByteBuffer buf) { 120 Map<String, HeaderValue> headers = new LinkedHashMap<>(); 121 122 while (buf.hasRemaining()) { 123 Header header = Header.decode(buf); 124 headers.put(header.getName(), header.getValue()); 125 } 126 127 return Collections.unmodifiableMap(headers); 128 } 129 toByteBuffer()130 public ByteBuffer toByteBuffer() { 131 try { 132 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 133 encode(baos); 134 baos.close(); 135 return ByteBuffer.wrap(baos.toByteArray()); 136 } catch (IOException e) { 137 throw new RuntimeException(e); 138 } 139 } 140 encode(OutputStream os)141 public void encode(OutputStream os) { 142 try { 143 CheckedOutputStream checkedOutputStream = new CheckedOutputStream(os, new CRC32()); 144 encodeOrThrow(checkedOutputStream); 145 long messageCrc = checkedOutputStream.getChecksum().getValue(); 146 os.write((int) (0xFF & messageCrc >> 24)); 147 os.write((int) (0xFF & messageCrc >> 16)); 148 os.write((int) (0xFF & messageCrc >> 8)); 149 os.write((int) (0xFF & messageCrc)); 150 151 os.flush(); 152 } catch (IOException ex) { 153 throw new RuntimeException(ex); 154 } 155 } 156 157 /** 158 * Encode the given {@code headers}, without any leading or trailing metadata such as checksums or lengths. 159 * 160 * @param headers a sequence of zero or more headers, which will be encoded in iteration order 161 * @return a byte array corresponding to the {@code headers} section of a {@code Message} 162 */ encodeHeaders(Iterable<Entry<String, HeaderValue>> headers)163 public static byte[] encodeHeaders(Iterable<Entry<String, HeaderValue>> headers) { 164 try { 165 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 166 DataOutputStream dos = new DataOutputStream(baos); 167 for (Entry<String, HeaderValue> entry : headers) { 168 Header.encode(entry, dos); 169 } 170 dos.close(); 171 return baos.toByteArray(); 172 } catch (IOException e) { 173 throw new RuntimeException(e); 174 } 175 } 176 encodeOrThrow(OutputStream os)177 private void encodeOrThrow(OutputStream os) throws IOException { 178 ByteArrayOutputStream headersAndPayload = new ByteArrayOutputStream(); 179 headersAndPayload.write(encodeHeaders(headers.entrySet())); 180 headersAndPayload.write(payload); 181 182 int totalLength = Prelude.LENGTH_WITH_CRC + headersAndPayload.size() + 4; 183 184 byte[] preludeBytes = getPrelude(totalLength); 185 Checksum crc = new CRC32(); 186 crc.update(preludeBytes, 0, preludeBytes.length); 187 188 DataOutputStream dos = new DataOutputStream(os); 189 dos.write(preludeBytes); 190 dos.writeInt((int) crc.getValue()); 191 dos.flush(); 192 193 headersAndPayload.writeTo(os); 194 } 195 getPrelude(int totalLength)196 private byte[] getPrelude(int totalLength) throws IOException { 197 ByteArrayOutputStream baos = new ByteArrayOutputStream(8); 198 DataOutputStream dos = new DataOutputStream(baos); 199 200 int headerLength = totalLength - Message.MESSAGE_OVERHEAD - payload.length; 201 dos.writeInt(totalLength); 202 dos.writeInt(headerLength); 203 204 dos.close(); 205 return baos.toByteArray(); 206 } 207 208 @Override equals(Object o)209 public boolean equals(Object o) { 210 if (this == o) return true; 211 if (o == null || getClass() != o.getClass()) return false; 212 213 Message message = (Message) o; 214 215 if (!headers.equals(message.headers)) return false; 216 return Arrays.equals(payload, message.payload); 217 } 218 219 @Override hashCode()220 public int hashCode() { 221 int result = headers.hashCode(); 222 result = 31 * result + Arrays.hashCode(payload); 223 return result; 224 } 225 226 @Override toString()227 public String toString() { 228 StringBuilder ret = new StringBuilder(); 229 230 for (Entry<String, HeaderValue> entry : headers.entrySet()) { 231 ret.append(entry.getKey()); 232 ret.append(": "); 233 ret.append(entry.getValue().toString()); 234 ret.append('\n'); 235 } 236 ret.append('\n'); 237 238 String contentType = headers.getOrDefault(":content-type", HeaderValue.fromString("application/octet-stream")) 239 .getString(); 240 if (contentType.contains("json") || contentType.contains("text")) { 241 ret.append(new String(payload, StandardCharsets.UTF_8)); 242 } else { 243 ret.append(Base64.getEncoder().encodeToString(payload)); 244 } 245 ret.append('\n'); 246 return ret.toString(); 247 } 248 } 249