1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + | package aws.sdk.kotlin.services.sqs
|
6 + |
|
7 + | import aws.sdk.kotlin.services.sqs.internal.ValidationEnabled
|
8 + | import aws.sdk.kotlin.services.sqs.internal.ValidationScope
|
9 + | import aws.sdk.kotlin.services.sqs.model.*
|
10 + | import aws.smithy.kotlin.runtime.client.ResponseInterceptorContext
|
11 + | import aws.smithy.kotlin.runtime.hashing.md5
|
12 + | import aws.smithy.kotlin.runtime.http.interceptors.ChecksumMismatchException
|
13 + | import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
|
14 + | import aws.smithy.kotlin.runtime.http.request.HttpRequest
|
15 + | import aws.smithy.kotlin.runtime.http.response.HttpResponse
|
16 + | import aws.smithy.kotlin.runtime.io.SdkBuffer
|
17 + | import aws.smithy.kotlin.runtime.telemetry.logging.Logger
|
18 + | import aws.smithy.kotlin.runtime.telemetry.logging.logger
|
19 + |
|
/**
 * Interceptor that validates MD5 checksums for SQS message operations.
 *
 * This interceptor performs client-side validation of MD5 checksums returned by SQS to ensure
 * message integrity during transmission. It validates the following components:
 * - Message body
 * - Message attributes
 * - Message system attributes (SendMessage and SendMessageBatch only)
 *
 * The validation behavior can be configured using:
 * - `checksumValidationEnabled` - Controls when validation occurs (ALWAYS, WHEN_SENDING, WHEN_RECEIVING, NEVER)
 * - `checksumValidationScopes` - Specifies which message components to validate
 *
 * Supported operations:
 * - SendMessage
 * - SendMessageBatch
 * - ReceiveMessage
 *
 * If MD5 hashing is unavailable on the current platform (for example under FIPS mode), an error is
 * logged and validation is skipped instead of failing the call.
 *
 * @param validationEnabled when to validate. [ValidationEnabled.NEVER] disables validation entirely,
 * `WHEN_SENDING` skips ReceiveMessage, and `WHEN_RECEIVING` skips SendMessage/SendMessageBatch. Any
 * other value, including `null`, validates every supported operation.
 * @param validationScopes the message components to validate; components not in this set are never checked.
 */
@OptIn(ExperimentalStdlibApi::class)
public class SqsMd5ChecksumValidationInterceptor(
    private val validationEnabled: ValidationEnabled?,
    private val validationScopes: Set<ValidationScope>,
) : HttpInterceptor {
    public companion object {
        // Transport-type marker bytes written before an attribute's value(s) in the hashed payload.
        private const val STRING_TYPE_FIELD_INDEX: Byte = 1
        private const val BINARY_TYPE_FIELD_INDEX: Byte = 2
        private const val STRING_LIST_TYPE_FIELD_INDEX: Byte = 3
        private const val BINARY_LIST_TYPE_FIELD_INDEX: Byte = 4
    }

    /**
     * Whether MD5 hashing works on this platform. Probed once per interceptor instance; this assumes
     * availability does not change over the interceptor's lifetime (e.g. it is fixed by FIPS mode).
     */
    private val isMd5Available: Boolean by lazy {
        try {
            "MD5".encodeToByteArray().md5()
            true
        } catch (e: Exception) {
            false
        }
    }

    override fun readAfterExecution(context: ResponseInterceptorContext<Any, Any, HttpRequest?, HttpResponse?>) {
        if (validationEnabled == ValidationEnabled.NEVER) return

        val request = context.request

        // Decide up front whether this operation/direction is validated at all, so unrelated SQS
        // operations never run (or log about) the MD5 availability check.
        val isValidatedOperation = when (request) {
            is SendMessageRequest, is SendMessageBatchRequest -> validationEnabled != ValidationEnabled.WHEN_RECEIVING
            is ReceiveMessageRequest -> validationEnabled != ValidationEnabled.WHEN_SENDING
            else -> false
        }
        if (!isValidatedOperation) return

        // A failed call has no output to validate.
        val response = context.response.getOrNull() ?: return

        val logger = context.executionContext.coroutineContext.logger<SqsMd5ChecksumValidationInterceptor>()

        if (!isMd5Available) {
            logger.error { "MD5 checksums are not available (likely due to FIPS mode). Checksum validation will be disabled." }
            return
        }

        when (request) {
            is SendMessageRequest ->
                sendMessageOperationMd5Check(request, response as SendMessageResponse, logger)
            is SendMessageBatchRequest ->
                sendMessageBatchOperationMd5Check(request, response as SendMessageBatchResponse, logger)
            is ReceiveMessageRequest ->
                receiveMessageResultMd5Check(response as ReceiveMessageResponse, logger)
        }
    }

    /**
     * Validates the SendMessage components selected by [validationScopes] against the digests in the
     * response. Components that were not sent (null or empty) are skipped.
     *
     * @throws ChecksumMismatchException if any validated digest differs
     */
    private fun sendMessageOperationMd5Check(
        sendMessageRequest: SendMessageRequest,
        sendMessageResponse: SendMessageResponse,
        logger: Logger,
    ) {
        val operation = "SendMessage"

        if (ValidationScope.MESSAGE_BODY in validationScopes) {
            val bodySent = sendMessageRequest.messageBody
            if (!bodySent.isNullOrEmpty()) {
                verifyMd5(
                    component = "message body",
                    operation = operation,
                    clientSideMd5 = calculateMessageBodyMd5(bodySent),
                    md5Returned = sendMessageResponse.md5OfMessageBody,
                    logger = logger,
                )
            }
        }

        if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
            val attributesSent = sendMessageRequest.messageAttributes
            if (!attributesSent.isNullOrEmpty()) {
                verifyMd5(
                    component = "message attribute",
                    operation = operation,
                    clientSideMd5 = calculateMessageAttributesMd5(attributesSent),
                    md5Returned = sendMessageResponse.md5OfMessageAttributes,
                    logger = logger,
                )
            }
        }

        if (ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES in validationScopes) {
            val systemAttributesSent = sendMessageRequest.messageSystemAttributes
            if (!systemAttributesSent.isNullOrEmpty()) {
                verifyMd5(
                    component = "message system attribute",
                    operation = operation,
                    clientSideMd5 = calculateMessageSystemAttributesMd5(systemAttributesSent),
                    md5Returned = sendMessageResponse.md5OfMessageSystemAttributes,
                    logger = logger,
                )
            }
        }
    }

    /**
     * Validates the body and message attributes (per [validationScopes]) of every received message.
     * System attributes are not validated for ReceiveMessage.
     *
     * @throws ChecksumMismatchException if any validated digest differs
     */
    private fun receiveMessageResultMd5Check(receiveMessageResponse: ReceiveMessageResponse, logger: Logger) {
        val operation = "ReceiveMessage"

        receiveMessageResponse.messages?.forEach { message ->
            if (ValidationScope.MESSAGE_BODY in validationScopes) {
                val body = message.body
                if (!body.isNullOrEmpty()) {
                    verifyMd5(
                        component = "message body",
                        operation = operation,
                        clientSideMd5 = calculateMessageBodyMd5(body),
                        md5Returned = message.md5OfBody,
                        logger = logger,
                    )
                }
            }

            if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
                val attributes = message.messageAttributes
                if (!attributes.isNullOrEmpty()) {
                    verifyMd5(
                        component = "message attribute",
                        operation = operation,
                        clientSideMd5 = calculateMessageAttributesMd5(attributes),
                        md5Returned = message.md5OfMessageAttributes,
                        logger = logger,
                    )
                }
            }
        }
    }

    /**
     * Validates each successful SendMessageBatch result entry against the request entry with the same id.
     * Result entries with no matching request entry are skipped.
     *
     * @throws ChecksumMismatchException if any validated digest differs
     */
    private fun sendMessageBatchOperationMd5Check(
        sendMessageBatchRequest: SendMessageBatchRequest,
        sendMessageBatchResponse: SendMessageBatchResponse,
        logger: Logger,
    ) {
        val requestEntriesById = sendMessageBatchRequest.entries.orEmpty().associateBy { it.id }

        for (resultEntry in sendMessageBatchResponse.successful) {
            // Without the originating request entry there is nothing to compute a digest from.
            val requestEntry = requestEntriesById[resultEntry.id] ?: continue
            val operation = "SendMessageBatch: ${resultEntry.messageId}"

            if (ValidationScope.MESSAGE_BODY in validationScopes) {
                val bodySent = requestEntry.messageBody
                if (!bodySent.isNullOrEmpty()) {
                    verifyMd5(
                        component = "message body",
                        operation = operation,
                        clientSideMd5 = calculateMessageBodyMd5(bodySent),
                        md5Returned = resultEntry.md5OfMessageBody,
                        logger = logger,
                    )
                }
            }

            if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
                val attributesSent = requestEntry.messageAttributes
                if (!attributesSent.isNullOrEmpty()) {
                    verifyMd5(
                        component = "message attribute",
                        operation = operation,
                        clientSideMd5 = calculateMessageAttributesMd5(attributesSent),
                        md5Returned = resultEntry.md5OfMessageAttributes,
                        logger = logger,
                    )
                }
            }

            if (ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES in validationScopes) {
                val systemAttributesSent = requestEntry.messageSystemAttributes
                if (!systemAttributesSent.isNullOrEmpty()) {
                    verifyMd5(
                        component = "message system attribute",
                        operation = operation,
                        clientSideMd5 = calculateMessageSystemAttributesMd5(systemAttributesSent),
                        md5Returned = resultEntry.md5OfMessageSystemAttributes,
                        logger = logger,
                    )
                }
            }
        }
    }

    /**
     * Compares the locally computed MD5 digest of one message component with the digest SQS returned.
     *
     * @param component lower-case component name used in log messages, e.g. "message body"
     * @param operation operation label used in log messages, e.g. "SendMessage"
     * @param clientSideMd5 hex MD5 digest computed by the client
     * @param md5Returned hex MD5 digest reported by SQS, or null if the response omitted it
     * @throws ChecksumMismatchException if the digests differ (hex letter case is ignored)
     */
    private fun verifyMd5(
        component: String,
        operation: String,
        clientSideMd5: String,
        md5Returned: String?,
        logger: Logger,
    ) {
        logger.debug { "Validating $component MD5 checksum for $operation" }

        // Hex digests are case-insensitive; a null returned digest never matches.
        if (!clientSideMd5.equals(md5Returned, ignoreCase = true)) {
            throw ChecksumMismatchException("Checksum mismatch. Expected $clientSideMd5 but was $md5Returned")
        }

        logger.debug { "${component.replaceFirstChar { it.uppercaseChar() }} MD5 checksum for $operation validated" }
    }

    /** Returns the lower-case hex MD5 digest of the UTF-8 encoded [messageBody]. */
    private fun calculateMessageBodyMd5(messageBody: String): String =
        messageBody.encodeToByteArray().md5().toHexString()

    /**
     * Calculates the MD5 digest for message attributes according to SQS specifications.
     * https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html#sqs-attributes-md5-message-digest-calculation
     *
     * Attributes are encoded in ascending name order and the concatenated encoding is hashed.
     */
    private fun calculateMessageAttributesMd5(messageAttributes: Map<String, MessageAttributeValue>): String {
        val buffer = SdkBuffer()

        messageAttributes.entries
            .sortedBy { it.key }
            .forEach { (name, value) ->
                writeAttribute(
                    buffer = buffer,
                    name = name,
                    dataType = value.dataType,
                    stringValue = value.stringValue,
                    binaryValue = value.binaryValue,
                    stringListValues = value.stringListValues,
                    binaryListValues = value.binaryListValues,
                )
            }

        return buffer.readByteArray().md5().toHexString()
    }

    /**
     * Calculates the MD5 digest for message system attributes, using the same encoding as
     * [calculateMessageAttributesMd5] with attributes ordered by their wire name.
     */
    private fun calculateMessageSystemAttributesMd5(
        messageSysAttrs: Map<MessageSystemAttributeNameForSends, MessageSystemAttributeValue>,
    ): String {
        val buffer = SdkBuffer()

        messageSysAttrs.entries
            .sortedBy { it.key.value }
            .forEach { (name, value) ->
                writeAttribute(
                    buffer = buffer,
                    name = name.value,
                    dataType = value.dataType,
                    stringValue = value.stringValue,
                    binaryValue = value.binaryValue,
                    stringListValues = value.stringListValues,
                    binaryListValues = value.binaryListValues,
                )
            }

        return buffer.readByteArray().md5().toHexString()
    }

    /**
     * Appends one attribute to [buffer]: its length-prefixed name and data type, then a type marker
     * byte followed by the length-prefixed value(s). The first non-null/non-empty value among string,
     * binary, string list and binary list is used; if none is set only name and data type are written.
     */
    private fun writeAttribute(
        buffer: SdkBuffer,
        name: String,
        dataType: String,
        stringValue: String?,
        binaryValue: ByteArray?,
        stringListValues: List<String>?,
        binaryListValues: List<ByteArray>?,
    ) {
        updateLengthAndBytes(buffer, name)
        updateLengthAndBytes(buffer, dataType)

        when {
            stringValue != null -> updateForStringType(buffer, stringValue)
            binaryValue != null -> updateForBinaryType(buffer, binaryValue)
            !stringListValues.isNullOrEmpty() -> updateForStringListType(buffer, stringListValues)
            !binaryListValues.isNullOrEmpty() -> updateForBinaryListType(buffer, binaryListValues)
        }
    }

    private fun updateForStringType(buffer: SdkBuffer, value: String) {
        buffer.writeByte(STRING_TYPE_FIELD_INDEX)
        updateLengthAndBytes(buffer, value)
    }

    private fun updateForBinaryType(buffer: SdkBuffer, value: ByteArray) {
        buffer.writeByte(BINARY_TYPE_FIELD_INDEX)
        updateLengthAndBytes(buffer, value)
    }

    private fun updateForStringListType(buffer: SdkBuffer, values: List<String>) {
        buffer.writeByte(STRING_LIST_TYPE_FIELD_INDEX)
        values.forEach { value -> updateLengthAndBytes(buffer, value) }
    }

    private fun updateForBinaryListType(buffer: SdkBuffer, values: List<ByteArray>) {
        buffer.writeByte(BINARY_LIST_TYPE_FIELD_INDEX)
        values.forEach { value -> updateLengthAndBytes(buffer, value) }
    }

    /** Writes the UTF-8 encoding of [stringValue], prefixed with its byte length. */
    private fun updateLengthAndBytes(buffer: SdkBuffer, stringValue: String) =
        updateLengthAndBytes(buffer, stringValue.encodeToByteArray())

    /**
     * Writes the length of [binaryValue] as a 4-byte integer (via [SdkBuffer.writeInt]) followed by
     * all of its bytes.
     */
    private fun updateLengthAndBytes(buffer: SdkBuffer, binaryValue: ByteArray) {
        buffer.writeInt(binaryValue.size)
        buffer.write(binaryValue)
    }
}
|