• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 package software.amazon.eventstream;
16 
17 import java.io.ByteArrayOutputStream;
18 import java.io.DataOutputStream;
19 import java.io.IOException;
20 import java.io.OutputStream;
21 import java.nio.ByteBuffer;
22 import java.nio.charset.StandardCharsets;
23 import java.util.Arrays;
24 import java.util.Base64;
25 import java.util.Collections;
26 import java.util.LinkedHashMap;
27 import java.util.Map;
28 import java.util.Map.Entry;
29 import java.util.zip.CRC32;
30 import java.util.zip.CheckedOutputStream;
31 import java.util.zip.Checksum;
32 
33 import static java.lang.String.format;
34 import static java.util.Objects.requireNonNull;
35 
36 /**
37  * An eventstream message.
38  */
39 public class Message {
40     private static final int TRAILING_CRC_LENGTH = 4;
41     static final int MESSAGE_OVERHEAD = Prelude.LENGTH_WITH_CRC + TRAILING_CRC_LENGTH;
42 
43     private final Map<String, HeaderValue> headers;
44     private final byte[] payload;
45 
46     /**
47      * Create a message.
48      *
49      * @param headers a non-null key-value map of headers, which will be encoded on the wire according to the
50      *     iteration order of the given map
51      * @param payload a non-null byte array containing zero or more bytes
52      */
Message(Map<String, HeaderValue> headers, byte[] payload)53     public Message(Map<String, HeaderValue> headers, byte[] payload) {
54         this.headers = requireNonNull(headers, "headers");
55         this.payload = requireNonNull(payload, "payload").clone();
56     }
57 
58     /**
59      * Returns this message's headers. If this message was obtained through {@link #decode(ByteBuffer)}, the iteration
60      * order of the headers is guaranteed to match the order in which the headers appeared on the wire.
61      *
62      * @return an unmodifiable non-null key-value map of this message's headers
63      */
getHeaders()64     public Map<String, HeaderValue> getHeaders() {
65         return headers;
66     }
67 
68     /**
69      * Returns this message's payload.
70      *
71      * @return a non-null array of zero or more bytes
72      */
getPayload()73     public byte[] getPayload() {
74         return payload.clone();
75     }
76 
decode(ByteBuffer buf)77     public static Message decode(ByteBuffer buf) {
78         return decode(Prelude.decode(buf), buf);
79     }
80 
81     /**
82      * Decodes a message with an already decoded prelude. Useful for not decoding the prelude twice.
83      *
84      * @param prelude Decoded prelude of message.
85      * @param buf Data of message (including prelude which will be skipped over).
86      * @return Decoded message
87      */
decode(Prelude prelude, ByteBuffer buf)88     static Message decode(Prelude prelude, ByteBuffer buf) {
89         int totalLength = prelude.getTotalLength();
90         validateMessageCrc(buf, totalLength);
91         buf.position(buf.position() + Prelude.LENGTH_WITH_CRC);
92 
93         long headersLength = prelude.getHeadersLength();
94         byte[] headerBytes = new byte[Math.toIntExact(headersLength)];
95         buf.get(headerBytes);
96         Map<String, HeaderValue> headers = decodeHeaders(ByteBuffer.wrap(headerBytes));
97 
98         byte[] payload = new byte[Math.toIntExact(totalLength - MESSAGE_OVERHEAD - headersLength)];
99         buf.get(payload);
100         buf.getInt(); // skip past the message CRC
101 
102         return new Message(headers, payload);
103     }
104 
validateMessageCrc(ByteBuffer buf, int totalLength)105     private static void validateMessageCrc(ByteBuffer buf, int totalLength) {
106         Checksum crc = new CRC32();
107 
108         Checksums.update(crc, (ByteBuffer) buf.duplicate().limit(buf.position() + totalLength - 4));
109         long computedMessageCrc = crc.getValue();
110 
111         long wireMessageCrc = Integer.toUnsignedLong(buf.getInt(buf.position() + totalLength - 4));
112 
113         if (wireMessageCrc != computedMessageCrc) {
114             throw new IllegalArgumentException(format("Message checksum failure: expected 0x%x, computed 0x%x",
115                 wireMessageCrc, computedMessageCrc));
116         }
117     }
118 
decodeHeaders(ByteBuffer buf)119     static Map<String, HeaderValue> decodeHeaders(ByteBuffer buf) {
120         Map<String, HeaderValue> headers = new LinkedHashMap<>();
121 
122         while (buf.hasRemaining()) {
123             Header header = Header.decode(buf);
124             headers.put(header.getName(), header.getValue());
125         }
126 
127         return Collections.unmodifiableMap(headers);
128     }
129 
toByteBuffer()130     public ByteBuffer toByteBuffer() {
131         try {
132             ByteArrayOutputStream baos = new ByteArrayOutputStream();
133             encode(baos);
134             baos.close();
135             return ByteBuffer.wrap(baos.toByteArray());
136         } catch (IOException e) {
137             throw new RuntimeException(e);
138         }
139     }
140 
encode(OutputStream os)141     public void encode(OutputStream os) {
142         try {
143             CheckedOutputStream checkedOutputStream = new CheckedOutputStream(os, new CRC32());
144             encodeOrThrow(checkedOutputStream);
145             long messageCrc = checkedOutputStream.getChecksum().getValue();
146             os.write((int) (0xFF & messageCrc >> 24));
147             os.write((int) (0xFF & messageCrc >> 16));
148             os.write((int) (0xFF & messageCrc >> 8));
149             os.write((int) (0xFF & messageCrc));
150 
151             os.flush();
152         } catch (IOException ex) {
153             throw new RuntimeException(ex);
154         }
155     }
156 
157     /**
158      * Encode the given {@code headers}, without any leading or trailing metadata such as checksums or lengths.
159      *
160      * @param headers a sequence of zero or more headers, which will be encoded in iteration order
161      * @return a byte array corresponding to the {@code headers} section of a {@code Message}
162      */
encodeHeaders(Iterable<Entry<String, HeaderValue>> headers)163     public static byte[] encodeHeaders(Iterable<Entry<String, HeaderValue>> headers) {
164         try {
165             ByteArrayOutputStream baos = new ByteArrayOutputStream();
166             DataOutputStream dos = new DataOutputStream(baos);
167             for (Entry<String, HeaderValue> entry : headers) {
168                 Header.encode(entry, dos);
169             }
170             dos.close();
171             return baos.toByteArray();
172         } catch (IOException e) {
173             throw new RuntimeException(e);
174         }
175     }
176 
encodeOrThrow(OutputStream os)177     private void encodeOrThrow(OutputStream os) throws IOException {
178         ByteArrayOutputStream headersAndPayload = new ByteArrayOutputStream();
179         headersAndPayload.write(encodeHeaders(headers.entrySet()));
180         headersAndPayload.write(payload);
181 
182         int totalLength = Prelude.LENGTH_WITH_CRC + headersAndPayload.size() + 4;
183 
184         byte[] preludeBytes = getPrelude(totalLength);
185         Checksum crc = new CRC32();
186         crc.update(preludeBytes, 0, preludeBytes.length);
187 
188         DataOutputStream dos = new DataOutputStream(os);
189         dos.write(preludeBytes);
190         dos.writeInt((int) crc.getValue());
191         dos.flush();
192 
193         headersAndPayload.writeTo(os);
194     }
195 
getPrelude(int totalLength)196     private byte[] getPrelude(int totalLength) throws IOException {
197         ByteArrayOutputStream baos = new ByteArrayOutputStream(8);
198         DataOutputStream dos = new DataOutputStream(baos);
199 
200         int headerLength = totalLength - Message.MESSAGE_OVERHEAD - payload.length;
201         dos.writeInt(totalLength);
202         dos.writeInt(headerLength);
203 
204         dos.close();
205         return baos.toByteArray();
206     }
207 
208     @Override
equals(Object o)209     public boolean equals(Object o) {
210         if (this == o) return true;
211         if (o == null || getClass() != o.getClass()) return false;
212 
213         Message message = (Message) o;
214 
215         if (!headers.equals(message.headers)) return false;
216         return Arrays.equals(payload, message.payload);
217     }
218 
219     @Override
hashCode()220     public int hashCode() {
221         int result = headers.hashCode();
222         result = 31 * result + Arrays.hashCode(payload);
223         return result;
224     }
225 
226     @Override
toString()227     public String toString() {
228         StringBuilder ret = new StringBuilder();
229 
230         for (Entry<String, HeaderValue> entry : headers.entrySet()) {
231             ret.append(entry.getKey());
232             ret.append(": ");
233             ret.append(entry.getValue().toString());
234             ret.append('\n');
235         }
236         ret.append('\n');
237 
238         String contentType = headers.getOrDefault(":content-type", HeaderValue.fromString("application/octet-stream"))
239             .getString();
240         if (contentType.contains("json") || contentType.contains("text")) {
241             ret.append(new String(payload, StandardCharsets.UTF_8));
242         } else {
243             ret.append(Base64.getEncoder().encodeToString(payload));
244         }
245         ret.append('\n');
246         return ret.toString();
247     }
248 }
249