• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.services.sqs.internal;
17 
18 import java.nio.ByteBuffer;
19 import java.nio.charset.StandardCharsets;
20 import java.security.MessageDigest;
21 import java.util.ArrayList;
22 import java.util.Collections;
23 import java.util.HashMap;
24 import java.util.List;
25 import java.util.Map;
26 import software.amazon.awssdk.annotations.SdkInternalApi;
27 import software.amazon.awssdk.core.SdkBytes;
28 import software.amazon.awssdk.core.SdkRequest;
29 import software.amazon.awssdk.core.SdkResponse;
30 import software.amazon.awssdk.core.exception.SdkClientException;
31 import software.amazon.awssdk.core.interceptor.Context;
32 import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
33 import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
34 import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
35 import software.amazon.awssdk.services.sqs.endpoints.SqsClientContextParams;
36 import software.amazon.awssdk.services.sqs.model.Message;
37 import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
38 import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
39 import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
40 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
41 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
42 import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
43 import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry;
44 import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
45 import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
46 import software.amazon.awssdk.utils.AttributeMap;
47 import software.amazon.awssdk.utils.BinaryUtils;
48 import software.amazon.awssdk.utils.Logger;
49 import software.amazon.awssdk.utils.Md5Utils;
50 
51 /**
52  * SQS operations on sending and receiving messages will return the MD5 digest of the message body.
53  * This custom request handler will verify that the message is correctly received by SQS, by
54  * comparing the returned MD5 with the calculation according to the original request.
55  */
56 @SdkInternalApi
57 public final class MessageMD5ChecksumInterceptor implements ExecutionInterceptor {
58 
59     private static final int INTEGER_SIZE_IN_BYTES = 4;
60     private static final byte STRING_TYPE_FIELD_INDEX = 1;
61     private static final byte BINARY_TYPE_FIELD_INDEX = 2;
62     private static final byte STRING_LIST_TYPE_FIELD_INDEX = 3;
63     private static final byte BINARY_LIST_TYPE_FIELD_INDEX = 4;
64 
65     /*
66      * Constant strings for composing error message.
67      */
68     private static final String MD5_MISMATCH_ERROR_MESSAGE =
69             "MD5 returned by SQS does not match the calculation on the original request. " +
70             "(MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
71     private static final String MD5_MISMATCH_ERROR_MESSAGE_WITH_ID =
72             "MD5 returned by SQS does not match the calculation on the original request. " +
73             "(Message ID: %s, MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
74     private static final String MESSAGE_BODY = "message body";
75     private static final String MESSAGE_ATTRIBUTES = "message attributes";
76 
77     private static final Logger log = Logger.loggerFor(MessageMD5ChecksumInterceptor.class);
78 
79     @Override
afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes)80     public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
81         SdkResponse response = context.response();
82         SdkRequest originalRequest = context.request();
83 
84         if (response != null && validateMessageMD5Enabled(executionAttributes)) {
85             if (originalRequest instanceof SendMessageRequest) {
86                 SendMessageRequest sendMessageRequest = (SendMessageRequest) originalRequest;
87                 SendMessageResponse sendMessageResult = (SendMessageResponse) response;
88                 sendMessageOperationMd5Check(sendMessageRequest, sendMessageResult);
89 
90             } else if (originalRequest instanceof ReceiveMessageRequest) {
91                 ReceiveMessageResponse receiveMessageResult = (ReceiveMessageResponse) response;
92                 receiveMessageResultMd5Check(receiveMessageResult);
93 
94             } else if (originalRequest instanceof SendMessageBatchRequest) {
95                 SendMessageBatchRequest sendMessageBatchRequest = (SendMessageBatchRequest) originalRequest;
96                 SendMessageBatchResponse sendMessageBatchResult = (SendMessageBatchResponse) response;
97                 sendMessageBatchOperationMd5Check(sendMessageBatchRequest, sendMessageBatchResult);
98             }
99         }
100     }
101 
validateMessageMD5Enabled(ExecutionAttributes executionAttributes)102     private static boolean validateMessageMD5Enabled(ExecutionAttributes executionAttributes) {
103         AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS);
104         Boolean enableMd5Validation = clientContextParams.get(SqsClientContextParams.CHECKSUM_VALIDATION_ENABLED);
105         return enableMd5Validation == null || enableMd5Validation;
106     }
107 
108     /**
109      * Throw an exception if the MD5 checksums returned in the SendMessageResponse do not match the
110      * client-side calculation based on the original message in the SendMessageRequest.
111      */
sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest, SendMessageResponse sendMessageResult)112     private static void sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest,
113                                                      SendMessageResponse sendMessageResult) {
114         String messageBodySent = sendMessageRequest.messageBody();
115         String bodyMd5Returned = sendMessageResult.md5OfMessageBody();
116         String clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent);
117         if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
118             throw SdkClientException.builder()
119                                     .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, clientSideBodyMd5,
120                                                           bodyMd5Returned))
121                                     .build();
122         }
123 
124         Map<String, MessageAttributeValue> messageAttrSent = sendMessageRequest.messageAttributes();
125         if (messageAttrSent != null && !messageAttrSent.isEmpty()) {
126             String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent);
127             String attrMd5Returned = sendMessageResult.md5OfMessageAttributes();
128             if (!clientSideAttrMd5.equals(attrMd5Returned)) {
129                 throw SdkClientException.builder()
130                                         .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES,
131                                                               clientSideAttrMd5, attrMd5Returned))
132                                         .build();
133             }
134         }
135     }
136 
137     /**
138      * Throw an exception if the MD5 checksums included in the ReceiveMessageResponse do not match the
139      * client-side calculation on the received messages.
140      */
receiveMessageResultMd5Check(ReceiveMessageResponse receiveMessageResult)141     private static void receiveMessageResultMd5Check(ReceiveMessageResponse receiveMessageResult) {
142         if (receiveMessageResult.messages() != null) {
143             for (Message messageReceived : receiveMessageResult.messages()) {
144                 String messageBody = messageReceived.body();
145                 String bodyMd5Returned = messageReceived.md5OfBody();
146                 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody);
147                 if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
148                     throw SdkClientException.builder()
149                                             .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY,
150                                                                   clientSideBodyMd5, bodyMd5Returned))
151                                             .build();
152                 }
153 
154                 Map<String, MessageAttributeValue> messageAttr = messageReceived.messageAttributes();
155                 if (messageAttr != null && !messageAttr.isEmpty()) {
156                     String attrMd5Returned = messageReceived.md5OfMessageAttributes();
157                     String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
158                     if (!clientSideAttrMd5.equals(attrMd5Returned)) {
159                         throw SdkClientException.builder()
160                                                 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES,
161                                                                       clientSideAttrMd5, attrMd5Returned))
162                                                 .build();
163                     }
164                 }
165             }
166         }
167     }
168 
169     /**
170      * Throw an exception if the MD5 checksums returned in the SendMessageBatchResponse do not match
171      * the client-side calculation based on the original messages in the SendMessageBatchRequest.
172      */
sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest, SendMessageBatchResponse sendMessageBatchResult)173     private static void sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest,
174                                                           SendMessageBatchResponse sendMessageBatchResult) {
175         Map<String, SendMessageBatchRequestEntry> idToRequestEntryMap = new HashMap<>();
176         if (sendMessageBatchRequest.entries() != null) {
177             for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) {
178                 idToRequestEntryMap.put(entry.id(), entry);
179             }
180         }
181 
182         if (sendMessageBatchResult.successful() != null) {
183             for (SendMessageBatchResultEntry entry : sendMessageBatchResult.successful()) {
184                 String messageBody = idToRequestEntryMap.get(entry.id()).messageBody();
185                 String bodyMd5Returned = entry.md5OfMessageBody();
186                 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody);
187                 if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
188                     throw SdkClientException.builder()
189                                             .message(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, MESSAGE_BODY,
190                                                                   entry.id(), clientSideBodyMd5, bodyMd5Returned))
191                                             .build();
192                 }
193 
194                 Map<String, MessageAttributeValue> messageAttr = idToRequestEntryMap.get(entry.id())
195                                                                                     .messageAttributes();
196                 if (messageAttr != null && !messageAttr.isEmpty()) {
197                     String attrMd5Returned = entry.md5OfMessageAttributes();
198                     String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
199                     if (!clientSideAttrMd5.equals(attrMd5Returned)) {
200                         throw SdkClientException.builder()
201                                                 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID,
202                                                                       MESSAGE_ATTRIBUTES, entry.id(), clientSideAttrMd5,
203                                                                       attrMd5Returned))
204                                                 .build();
205                     }
206                 }
207             }
208         }
209     }
210 
211     /**
212      * Returns the hex-encoded MD5 hash String of the given message body.
213      */
calculateMessageBodyMd5(String messageBody)214     private static String calculateMessageBodyMd5(String messageBody) {
215         log.debug(() -> "Message body: " + messageBody);
216         byte[] expectedMd5;
217         try {
218             expectedMd5 = Md5Utils.computeMD5Hash(messageBody.getBytes(StandardCharsets.UTF_8));
219         } catch (Exception e) {
220             throw SdkClientException.builder()
221                                     .message("Unable to calculate the MD5 hash of the message body. " + e.getMessage())
222                                     .cause(e)
223                                     .build();
224         }
225         String expectedMd5Hex = BinaryUtils.toHex(expectedMd5);
226         log.debug(() -> "Expected  MD5 of message body: " + expectedMd5Hex);
227         return expectedMd5Hex;
228     }
229 
230     /**
231      * Returns the hex-encoded MD5 hash String of the given message attributes.
232      */
calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes)233     private static String calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes) {
234         log.debug(() -> "Message attributes: " + messageAttributes);
235         List<String> sortedAttributeNames = new ArrayList<>(messageAttributes.keySet());
236         Collections.sort(sortedAttributeNames);
237 
238         MessageDigest md5Digest;
239         try {
240             md5Digest = MessageDigest.getInstance("MD5");
241 
242             for (String attrName : sortedAttributeNames) {
243                 MessageAttributeValue attrValue = messageAttributes.get(attrName);
244 
245                 // Encoded Name
246                 updateLengthAndBytes(md5Digest, attrName);
247 
248                 // Encoded Type
249                 updateLengthAndBytes(md5Digest, attrValue.dataType());
250 
251                 // Encoded Value
252                 if (attrValue.stringValue() != null) {
253                     md5Digest.update(STRING_TYPE_FIELD_INDEX);
254                     updateLengthAndBytes(md5Digest, attrValue.stringValue());
255                 } else if (attrValue.binaryValue() != null) {
256                     md5Digest.update(BINARY_TYPE_FIELD_INDEX);
257                     updateLengthAndBytes(md5Digest, attrValue.binaryValue().asByteBuffer());
258                 } else if (attrValue.stringListValues() != null &&
259                            attrValue.stringListValues().size() > 0) {
260                     md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX);
261                     for (String strListMember : attrValue.stringListValues()) {
262                         updateLengthAndBytes(md5Digest, strListMember);
263                     }
264                 } else if (attrValue.binaryListValues() != null &&
265                            attrValue.binaryListValues().size() > 0) {
266                     md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX);
267                     for (SdkBytes byteListMember : attrValue.binaryListValues()) {
268                         updateLengthAndBytes(md5Digest, byteListMember.asByteBuffer());
269                     }
270                 }
271             }
272         } catch (Exception e) {
273             throw SdkClientException.builder()
274                                     .message("Unable to calculate the MD5 hash of the message attributes. " + e.getMessage())
275                                     .cause(e)
276                                     .build();
277         }
278 
279         String expectedMd5Hex = BinaryUtils.toHex(md5Digest.digest());
280         log.debug(() -> "Expected  MD5 of message attributes: " + expectedMd5Hex);
281         return expectedMd5Hex;
282     }
283 
284     /**
285      * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
286      * input String and the actual utf8-encoded byte values.
287      */
updateLengthAndBytes(MessageDigest digest, String str)288     private static void updateLengthAndBytes(MessageDigest digest, String str) {
289         byte[] utf8Encoded = str.getBytes(StandardCharsets.UTF_8);
290         ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(utf8Encoded.length);
291         digest.update(lengthBytes.array());
292         digest.update(utf8Encoded);
293     }
294 
295     /**
296      * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
297      * input ByteBuffer and all the bytes it contains.
298      */
updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue)299     private static void updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue) {
300         ByteBuffer readOnlyBuffer = binaryValue.asReadOnlyBuffer();
301         int size = readOnlyBuffer.remaining();
302         ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(size);
303         digest.update(lengthBytes.array());
304         digest.update(readOnlyBuffer);
305     }
306 }
307