1 /* 2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"). 5 * You may not use this file except in compliance with the License. 6 * A copy of the License is located at 7 * 8 * http://aws.amazon.com/apache2.0 9 * 10 * or in the "license" file accompanying this file. This file is distributed 11 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 * express or implied. See the License for the specific language governing 13 * permissions and limitations under the License. 14 */ 15 16 package software.amazon.awssdk.services.sqs.internal; 17 18 import java.nio.ByteBuffer; 19 import java.nio.charset.StandardCharsets; 20 import java.security.MessageDigest; 21 import java.util.ArrayList; 22 import java.util.Collections; 23 import java.util.HashMap; 24 import java.util.List; 25 import java.util.Map; 26 import software.amazon.awssdk.annotations.SdkInternalApi; 27 import software.amazon.awssdk.core.SdkBytes; 28 import software.amazon.awssdk.core.SdkRequest; 29 import software.amazon.awssdk.core.SdkResponse; 30 import software.amazon.awssdk.core.exception.SdkClientException; 31 import software.amazon.awssdk.core.interceptor.Context; 32 import software.amazon.awssdk.core.interceptor.ExecutionAttributes; 33 import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; 34 import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; 35 import software.amazon.awssdk.services.sqs.endpoints.SqsClientContextParams; 36 import software.amazon.awssdk.services.sqs.model.Message; 37 import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; 38 import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; 39 import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; 40 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; 41 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; 42 import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; 43 import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry; 44 import software.amazon.awssdk.services.sqs.model.SendMessageRequest; 45 import software.amazon.awssdk.services.sqs.model.SendMessageResponse; 46 import software.amazon.awssdk.utils.AttributeMap; 47 import software.amazon.awssdk.utils.BinaryUtils; 48 import software.amazon.awssdk.utils.Logger; 49 import software.amazon.awssdk.utils.Md5Utils; 50 51 /** 52 * SQS operations on sending and receiving messages will return the MD5 digest of the message body. 53 * This custom request handler will verify that the message is correctly received by SQS, by 54 * comparing the returned MD5 with the calculation according to the original request. 55 */ 56 @SdkInternalApi 57 public final class MessageMD5ChecksumInterceptor implements ExecutionInterceptor { 58 59 private static final int INTEGER_SIZE_IN_BYTES = 4; 60 private static final byte STRING_TYPE_FIELD_INDEX = 1; 61 private static final byte BINARY_TYPE_FIELD_INDEX = 2; 62 private static final byte STRING_LIST_TYPE_FIELD_INDEX = 3; 63 private static final byte BINARY_LIST_TYPE_FIELD_INDEX = 4; 64 65 /* 66 * Constant strings for composing error message. 67 */ 68 private static final String MD5_MISMATCH_ERROR_MESSAGE = 69 "MD5 returned by SQS does not match the calculation on the original request. " + 70 "(MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")"; 71 private static final String MD5_MISMATCH_ERROR_MESSAGE_WITH_ID = 72 "MD5 returned by SQS does not match the calculation on the original request. " + 73 "(Message ID: %s, MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")"; 74 private static final String MESSAGE_BODY = "message body"; 75 private static final String MESSAGE_ATTRIBUTES = "message attributes"; 76 77 private static final Logger log = Logger.loggerFor(MessageMD5ChecksumInterceptor.class); 78 79 @Override afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes)80 public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) { 81 SdkResponse response = context.response(); 82 SdkRequest originalRequest = context.request(); 83 84 if (response != null && validateMessageMD5Enabled(executionAttributes)) { 85 if (originalRequest instanceof SendMessageRequest) { 86 SendMessageRequest sendMessageRequest = (SendMessageRequest) originalRequest; 87 SendMessageResponse sendMessageResult = (SendMessageResponse) response; 88 sendMessageOperationMd5Check(sendMessageRequest, sendMessageResult); 89 90 } else if (originalRequest instanceof ReceiveMessageRequest) { 91 ReceiveMessageResponse receiveMessageResult = (ReceiveMessageResponse) response; 92 receiveMessageResultMd5Check(receiveMessageResult); 93 94 } else if (originalRequest instanceof SendMessageBatchRequest) { 95 SendMessageBatchRequest sendMessageBatchRequest = (SendMessageBatchRequest) originalRequest; 96 SendMessageBatchResponse sendMessageBatchResult = (SendMessageBatchResponse) response; 97 sendMessageBatchOperationMd5Check(sendMessageBatchRequest, sendMessageBatchResult); 98 } 99 } 100 } 101 validateMessageMD5Enabled(ExecutionAttributes executionAttributes)102 private static boolean validateMessageMD5Enabled(ExecutionAttributes executionAttributes) { 103 AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS); 104 Boolean enableMd5Validation = clientContextParams.get(SqsClientContextParams.CHECKSUM_VALIDATION_ENABLED); 105 return enableMd5Validation == null || enableMd5Validation; 106 } 107 108 /** 109 * Throw an exception if the MD5 checksums returned in the SendMessageResponse do not match the 110 * client-side calculation based on the original message in the SendMessageRequest. 111 */ sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest, SendMessageResponse sendMessageResult)112 private static void sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest, 113 SendMessageResponse sendMessageResult) { 114 String messageBodySent = sendMessageRequest.messageBody(); 115 String bodyMd5Returned = sendMessageResult.md5OfMessageBody(); 116 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent); 117 if (!clientSideBodyMd5.equals(bodyMd5Returned)) { 118 throw SdkClientException.builder() 119 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, clientSideBodyMd5, 120 bodyMd5Returned)) 121 .build(); 122 } 123 124 Map<String, MessageAttributeValue> messageAttrSent = sendMessageRequest.messageAttributes(); 125 if (messageAttrSent != null && !messageAttrSent.isEmpty()) { 126 String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent); 127 String attrMd5Returned = sendMessageResult.md5OfMessageAttributes(); 128 if (!clientSideAttrMd5.equals(attrMd5Returned)) { 129 throw SdkClientException.builder() 130 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES, 131 clientSideAttrMd5, attrMd5Returned)) 132 .build(); 133 } 134 } 135 } 136 137 /** 138 * Throw an exception if the MD5 checksums included in the ReceiveMessageResponse do not match the 139 * client-side calculation on the received messages. 140 */ receiveMessageResultMd5Check(ReceiveMessageResponse receiveMessageResult)141 private static void receiveMessageResultMd5Check(ReceiveMessageResponse receiveMessageResult) { 142 if (receiveMessageResult.messages() != null) { 143 for (Message messageReceived : receiveMessageResult.messages()) { 144 String messageBody = messageReceived.body(); 145 String bodyMd5Returned = messageReceived.md5OfBody(); 146 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody); 147 if (!clientSideBodyMd5.equals(bodyMd5Returned)) { 148 throw SdkClientException.builder() 149 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, 150 clientSideBodyMd5, bodyMd5Returned)) 151 .build(); 152 } 153 154 Map<String, MessageAttributeValue> messageAttr = messageReceived.messageAttributes(); 155 if (messageAttr != null && !messageAttr.isEmpty()) { 156 String attrMd5Returned = messageReceived.md5OfMessageAttributes(); 157 String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr); 158 if (!clientSideAttrMd5.equals(attrMd5Returned)) { 159 throw SdkClientException.builder() 160 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES, 161 clientSideAttrMd5, attrMd5Returned)) 162 .build(); 163 } 164 } 165 } 166 } 167 } 168 169 /** 170 * Throw an exception if the MD5 checksums returned in the SendMessageBatchResponse do not match 171 * the client-side calculation based on the original messages in the SendMessageBatchRequest. 172 */ sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest, SendMessageBatchResponse sendMessageBatchResult)173 private static void sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest, 174 SendMessageBatchResponse sendMessageBatchResult) { 175 Map<String, SendMessageBatchRequestEntry> idToRequestEntryMap = new HashMap<>(); 176 if (sendMessageBatchRequest.entries() != null) { 177 for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) { 178 idToRequestEntryMap.put(entry.id(), entry); 179 } 180 } 181 182 if (sendMessageBatchResult.successful() != null) { 183 for (SendMessageBatchResultEntry entry : sendMessageBatchResult.successful()) { 184 String messageBody = idToRequestEntryMap.get(entry.id()).messageBody(); 185 String bodyMd5Returned = entry.md5OfMessageBody(); 186 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody); 187 if (!clientSideBodyMd5.equals(bodyMd5Returned)) { 188 throw SdkClientException.builder() 189 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, MESSAGE_BODY, 190 entry.id(), clientSideBodyMd5, bodyMd5Returned)) 191 .build(); 192 } 193 194 Map<String, MessageAttributeValue> messageAttr = idToRequestEntryMap.get(entry.id()) 195 .messageAttributes(); 196 if (messageAttr != null && !messageAttr.isEmpty()) { 197 String attrMd5Returned = entry.md5OfMessageAttributes(); 198 String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr); 199 if (!clientSideAttrMd5.equals(attrMd5Returned)) { 200 throw SdkClientException.builder() 201 .message(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, 202 MESSAGE_ATTRIBUTES, entry.id(), clientSideAttrMd5, 203 attrMd5Returned)) 204 .build(); 205 } 206 } 207 } 208 } 209 } 210 211 /** 212 * Returns the hex-encoded MD5 hash String of the given message body. 213 */ calculateMessageBodyMd5(String messageBody)214 private static String calculateMessageBodyMd5(String messageBody) { 215 log.debug(() -> "Message body: " + messageBody); 216 byte[] expectedMd5; 217 try { 218 expectedMd5 = Md5Utils.computeMD5Hash(messageBody.getBytes(StandardCharsets.UTF_8)); 219 } catch (Exception e) { 220 throw SdkClientException.builder() 221 .message("Unable to calculate the MD5 hash of the message body. " + e.getMessage()) 222 .cause(e) 223 .build(); 224 } 225 String expectedMd5Hex = BinaryUtils.toHex(expectedMd5); 226 log.debug(() -> "Expected MD5 of message body: " + expectedMd5Hex); 227 return expectedMd5Hex; 228 } 229 230 /** 231 * Returns the hex-encoded MD5 hash String of the given message attributes. 232 */ calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes)233 private static String calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes) { 234 log.debug(() -> "Message attributes: " + messageAttributes); 235 List<String> sortedAttributeNames = new ArrayList<>(messageAttributes.keySet()); 236 Collections.sort(sortedAttributeNames); 237 238 MessageDigest md5Digest; 239 try { 240 md5Digest = MessageDigest.getInstance("MD5"); 241 242 for (String attrName : sortedAttributeNames) { 243 MessageAttributeValue attrValue = messageAttributes.get(attrName); 244 245 // Encoded Name 246 updateLengthAndBytes(md5Digest, attrName); 247 248 // Encoded Type 249 updateLengthAndBytes(md5Digest, attrValue.dataType()); 250 251 // Encoded Value 252 if (attrValue.stringValue() != null) { 253 md5Digest.update(STRING_TYPE_FIELD_INDEX); 254 updateLengthAndBytes(md5Digest, attrValue.stringValue()); 255 } else if (attrValue.binaryValue() != null) { 256 md5Digest.update(BINARY_TYPE_FIELD_INDEX); 257 updateLengthAndBytes(md5Digest, attrValue.binaryValue().asByteBuffer()); 258 } else if (attrValue.stringListValues() != null && 259 attrValue.stringListValues().size() > 0) { 260 md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX); 261 for (String strListMember : attrValue.stringListValues()) { 262 updateLengthAndBytes(md5Digest, strListMember); 263 } 264 } else if (attrValue.binaryListValues() != null && 265 attrValue.binaryListValues().size() > 0) { 266 md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX); 267 for (SdkBytes byteListMember : attrValue.binaryListValues()) { 268 updateLengthAndBytes(md5Digest, byteListMember.asByteBuffer()); 269 } 270 } 271 } 272 } catch (Exception e) { 273 throw SdkClientException.builder() 274 .message("Unable to calculate the MD5 hash of the message attributes. " + e.getMessage()) 275 .cause(e) 276 .build(); 277 } 278 279 String expectedMd5Hex = BinaryUtils.toHex(md5Digest.digest()); 280 log.debug(() -> "Expected MD5 of message attributes: " + expectedMd5Hex); 281 return expectedMd5Hex; 282 } 283 284 /** 285 * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the 286 * input String and the actual utf8-encoded byte values. 287 */ updateLengthAndBytes(MessageDigest digest, String str)288 private static void updateLengthAndBytes(MessageDigest digest, String str) { 289 byte[] utf8Encoded = str.getBytes(StandardCharsets.UTF_8); 290 ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(utf8Encoded.length); 291 digest.update(lengthBytes.array()); 292 digest.update(utf8Encoded); 293 } 294 295 /** 296 * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the 297 * input ByteBuffer and all the bytes it contains. 298 */ updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue)299 private static void updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue) { 300 ByteBuffer readOnlyBuffer = binaryValue.asReadOnlyBuffer(); 301 int size = readOnlyBuffer.remaining(); 302 ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(size); 303 digest.update(lengthBytes.array()); 304 digest.update(readOnlyBuffer); 305 } 306 } 307