1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + | package aws.sdk.kotlin.services.sqs
|
6 + |
|
7 + | import aws.sdk.kotlin.runtime.ClientException
|
8 + | import aws.sdk.kotlin.services.sqs.internal.ValidationEnabled
|
9 + | import aws.sdk.kotlin.services.sqs.internal.ValidationScope
|
10 + | import aws.sdk.kotlin.services.sqs.model.*
|
11 + | import aws.smithy.kotlin.runtime.client.ResponseInterceptorContext
|
12 + | import aws.smithy.kotlin.runtime.collections.AttributeKey
|
13 + | import aws.smithy.kotlin.runtime.hashing.md5
|
14 + | import aws.smithy.kotlin.runtime.http.interceptors.ChecksumMismatchException
|
15 + | import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
|
16 + | import aws.smithy.kotlin.runtime.http.request.HttpRequest
|
17 + | import aws.smithy.kotlin.runtime.http.response.HttpResponse
|
18 + | import aws.smithy.kotlin.runtime.io.SdkBuffer
|
19 + | import aws.smithy.kotlin.runtime.telemetry.logging.Logger
|
20 + | import aws.smithy.kotlin.runtime.telemetry.logging.error
|
21 + | import aws.smithy.kotlin.runtime.telemetry.logging.logger
|
22 + | import aws.smithy.kotlin.runtime.util.asyncLazy
|
23 + | import kotlinx.coroutines.runBlocking
|
24 + | import kotlin.coroutines.coroutineContext
|
25 + |
|
/** Marker byte written to the digest input ahead of a single string attribute value. */
private const val STRING_TYPE_FIELD_INDEX: Byte = 0x01

/** Marker byte written to the digest input ahead of a single binary attribute value. */
private const val BINARY_TYPE_FIELD_INDEX: Byte = 0x02

/** Marker byte written to the digest input ahead of a list of string attribute values. */
private const val STRING_LIST_TYPE_FIELD_INDEX: Byte = 0x03

/** Marker byte written to the digest input ahead of a list of binary attribute values. */
private const val BINARY_LIST_TYPE_FIELD_INDEX: Byte = 0x04
|
30 + |
|
/**
 * Interceptor that validates MD5 checksums for SQS message operations.
 *
 * After an operation completes, the MD5 digests returned by SQS are compared against digests computed
 * client-side. A mismatch raises [ChecksumMismatchException]. The following components can be validated:
 * - Message body
 * - Message attributes
 * - Message system attributes (SendMessage and SendMessageBatch only; ReceiveMessage is checked for body and
 *   message attributes)
 *
 * Supported operations:
 * - SendMessage
 * - SendMessageBatch
 * - ReceiveMessage
 *
 * Requests of any other type are not validated. Validation is skipped entirely when MD5 is unavailable in the
 * current environment.
 *
 * @param validationEnabled controls when validation occurs (`ALWAYS`, `WHEN_SENDING`, `WHEN_RECEIVING`, `NEVER`)
 * @param validationScopes the message components to validate; components not in this set are skipped
 */
@OptIn(ExperimentalStdlibApi::class)
public class SqsMd5ChecksumValidationInterceptor(
    private val validationEnabled: ValidationEnabled,
    private val validationScopes: Set<ValidationScope>,
) : HttpInterceptor {
    public companion object {
        /** Execution-context flag set once the interceptor has run; read by e2e test assertions. */
        private val CHECKSUM_VALIDATED: AttributeKey<Boolean> = AttributeKey("checksumValidated")

        /**
         * Probes whether an MD5 digest can be computed at all. If the probe throws, the failure is logged and
         * validation is disabled rather than failing every request.
         */
        private val isMd5Available = asyncLazy {
            try {
                "MD5".encodeToByteArray().md5()
                true
            } catch (e: Exception) {
                coroutineContext.error<SqsMd5ChecksumValidationInterceptor>(e) {
                    "MD5 checksums are not available (likely due to FIPS mode). Checksum validation will be disabled."
                }
                false
            }
        }
    }

    /**
     * Validates the checksums of a completed SendMessage, SendMessageBatch or ReceiveMessage call.
     *
     * Returns without doing anything when validation is disabled, when MD5 is unavailable, or when
     * [validationEnabled] excludes the direction (sending / receiving) of the current operation.
     *
     * @throws ChecksumMismatchException if a returned digest differs from the client-side digest
     * @throws ClientException if a message attribute carries no value of any supported type
     */
    override fun readAfterExecution(context: ResponseInterceptorContext<Any, Any, HttpRequest?, HttpResponse?>) {
        if (validationEnabled == ValidationEnabled.NEVER) return

        // This hook is not suspending, so the async-lazy MD5 availability probe has to be awaited by blocking.
        if (runBlocking { !isMd5Available.get() }) return

        val logger = context.executionContext.coroutineContext.logger<SqsMd5ChecksumValidationInterceptor>()
        val request = context.request

        context.response.getOrNull()?.let { response ->
            when (request) {
                is SendMessageRequest -> {
                    if (validationEnabled == ValidationEnabled.WHEN_RECEIVING) return
                    sendMessageOperationMd5Check(request, response as SendMessageResponse, logger)
                }

                is ReceiveMessageRequest -> {
                    if (validationEnabled == ValidationEnabled.WHEN_SENDING) return
                    receiveMessageResultMd5Check(response as ReceiveMessageResponse, logger)
                }

                is SendMessageBatchRequest -> {
                    if (validationEnabled == ValidationEnabled.WHEN_RECEIVING) return
                    sendMessageBatchOperationMd5Check(request, response as SendMessageBatchResponse, logger)
                }
            }
        }

        // Sets validation flag in execution context for e2e test assertions
        context.executionContext[CHECKSUM_VALIDATED] = true
    }

    /** Validates body, message attribute and system attribute digests of a SendMessage response. */
    private fun sendMessageOperationMd5Check(
        sendMessageRequest: SendMessageRequest,
        sendMessageResponse: SendMessageResponse,
        logger: Logger,
    ) {
        val target = "SendMessage"

        validateBodyMd5(target, sendMessageResponse.md5OfMessageBody, sendMessageRequest.messageBody, logger)
        validateAttributesMd5(
            target,
            sendMessageResponse.md5OfMessageAttributes,
            sendMessageRequest.messageAttributes,
            logger,
        )
        validateSystemAttributesMd5(
            target,
            sendMessageResponse.md5OfMessageSystemAttributes,
            sendMessageRequest.messageSystemAttributes,
            logger,
        )
    }

    /** Validates body and message attribute digests of every message in a ReceiveMessage response. */
    private fun receiveMessageResultMd5Check(receiveMessageResponse: ReceiveMessageResponse, logger: Logger) {
        val target = "ReceiveMessage"

        receiveMessageResponse.messages?.forEach { messageReceived ->
            validateBodyMd5(target, messageReceived.md5OfBody, messageReceived.body, logger)
            validateAttributesMd5(
                target,
                messageReceived.md5OfMessageAttributes,
                messageReceived.messageAttributes,
                logger,
            )
        }
    }

    /**
     * Validates the digests of every successful entry of a SendMessageBatch response against the request entry
     * with the same id. Response entries without a matching request entry are skipped.
     */
    private fun sendMessageBatchOperationMd5Check(
        sendMessageBatchRequest: SendMessageBatchRequest,
        sendMessageBatchResponse: SendMessageBatchResponse,
        logger: Logger,
    ) {
        val idToRequestEntry = sendMessageBatchRequest
            .entries
            .orEmpty()
            .associateBy { it.id }

        for (entry in sendMessageBatchResponse.successful) {
            val requestEntry = idToRequestEntry[entry.id]
            val target = "SendMessageBatch: ${entry.messageId}"

            validateBodyMd5(target, entry.md5OfMessageBody, requestEntry?.messageBody, logger)
            validateAttributesMd5(target, entry.md5OfMessageAttributes, requestEntry?.messageAttributes, logger)
            validateSystemAttributesMd5(
                target,
                entry.md5OfMessageSystemAttributes,
                requestEntry?.messageSystemAttributes,
                logger,
            )
        }
    }

    /** Validates a message body digest if the body scope is enabled and both values are present. */
    private fun validateBodyMd5(target: String, md5Returned: String?, body: String?, logger: Logger) {
        validateComponentMd5(
            scope = ValidationScope.MESSAGE_BODY,
            component = "body",
            target = target,
            md5Returned = md5Returned,
            payload = body?.takeIf { it.isNotEmpty() },
            logger = logger,
        ) { calculateMessageBodyMd5(it) }
    }

    /** Validates a message attribute digest if the attribute scope is enabled and both values are present. */
    private fun validateAttributesMd5(
        target: String,
        md5Returned: String?,
        messageAttributes: Map<String, MessageAttributeValue>?,
        logger: Logger,
    ) {
        validateComponentMd5(
            scope = ValidationScope.MESSAGE_ATTRIBUTES,
            component = "attribute",
            target = target,
            md5Returned = md5Returned,
            payload = messageAttributes?.takeIf { it.isNotEmpty() },
            logger = logger,
        ) { calculateMessageAttributesMd5(it) }
    }

    /** Validates a system attribute digest if the system attribute scope is enabled and both values are present. */
    private fun validateSystemAttributesMd5(
        target: String,
        md5Returned: String?,
        messageSystemAttributes: Map<MessageSystemAttributeNameForSends, MessageSystemAttributeValue>?,
        logger: Logger,
    ) {
        validateComponentMd5(
            scope = ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES,
            component = "system attribute",
            target = target,
            md5Returned = md5Returned,
            payload = messageSystemAttributes?.takeIf { it.isNotEmpty() },
            logger = logger,
        ) { calculateMessageSystemAttributesMd5(it) }
    }

    /**
     * Shared validation step for one message component.
     *
     * Does nothing unless [scope] is enabled, [md5Returned] is non-empty and [payload] is non-null (callers pass
     * `null` for an absent or empty payload). Otherwise the digest of [payload] is computed with [computeMd5] and
     * compared with [md5Returned].
     *
     * @param component component name used in log messages (e.g. `body`)
     * @param target operation description used in log messages (e.g. `SendMessage`)
     * @throws ChecksumMismatchException if the digests differ
     */
    private fun <T : Any> validateComponentMd5(
        scope: ValidationScope,
        component: String,
        target: String,
        md5Returned: String?,
        payload: T?,
        logger: Logger,
        computeMd5: (T) -> String,
    ) {
        if (!validationScopes.contains(scope)) return
        if (md5Returned.isNullOrEmpty() || payload == null) return

        logger.debug { "Validating message $component MD5 checksum for $target" }

        validateMd5(computeMd5(payload), md5Returned)

        logger.debug { "Message $component MD5 checksum for $target validated" }
    }

    /**
     * Compares the client-side digest with the digest returned by the service.
     *
     * Both values are hex strings, so the comparison ignores letter case.
     *
     * @throws ChecksumMismatchException if the digests differ
     */
    private fun validateMd5(clientSideMd5: String, md5Returned: String) {
        if (!clientSideMd5.equals(md5Returned, ignoreCase = true)) {
            throw ChecksumMismatchException("Checksum mismatch. Expected $clientSideMd5 but was $md5Returned")
        }
    }

    /** Returns the hex-encoded MD5 digest of the UTF-8 bytes of [messageBody]. */
    private fun calculateMessageBodyMd5(messageBody: String) =
        messageBody.encodeToByteArray().md5().toHexString()

    /**
     * Calculates the MD5 digest for message attributes according to SQS specifications.
     * https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html#sqs-attributes-md5-message-digest-calculation
     *
     * Attributes are encoded in ascending name order; each contributes its name, data type, a value-type marker
     * byte and its value(s).
     *
     * @throws ClientException if an attribute has no value of any supported type
     */
    private fun calculateMessageAttributesMd5(messageAttributes: Map<String, MessageAttributeValue>): String {
        val buffer = SdkBuffer()

        messageAttributes
            .entries
            .sortedBy { (attributeName, _) -> attributeName }
            .forEach { (attributeName, attributeValue) ->
                updateLengthAndBytes(buffer, attributeName)
                updateLengthAndBytes(buffer, attributeValue.dataType)

                val encoded = updateForValue(
                    buffer,
                    attributeValue.stringValue,
                    attributeValue.binaryValue,
                    attributeValue.stringListValues,
                    attributeValue.binaryListValues,
                )
                if (!encoded) throw ClientException("No value type found for attribute $attributeName")
            }

        return buffer.readByteArray().md5().toHexString()
    }

    /**
     * Calculates the MD5 digest for message system attributes, using the same encoding as
     * [calculateMessageAttributesMd5] with the attribute name taken from the enum's string value.
     *
     * @throws ClientException if a system attribute has no value of any supported type
     */
    private fun calculateMessageSystemAttributesMd5(
        messageSystemAttributes: Map<MessageSystemAttributeNameForSends, MessageSystemAttributeValue>,
    ): String {
        val buffer = SdkBuffer()

        messageSystemAttributes
            .entries
            .sortedBy { (systemAttributeName, _) -> systemAttributeName.value }
            .forEach { (systemAttributeName, systemAttributeValue) ->
                updateLengthAndBytes(buffer, systemAttributeName.value)
                updateLengthAndBytes(buffer, systemAttributeValue.dataType)

                val encoded = updateForValue(
                    buffer,
                    systemAttributeValue.stringValue,
                    systemAttributeValue.binaryValue,
                    systemAttributeValue.stringListValues,
                    systemAttributeValue.binaryListValues,
                )
                if (!encoded) throw ClientException("No value type found for system attribute $systemAttributeName")
            }

        return buffer.readByteArray().md5().toHexString()
    }

    /**
     * Writes the first available value to [buffer], checked in the order: string, binary, string list,
     * binary list. Empty lists count as absent.
     *
     * @return `true` if a value was written, `false` if the attribute carries no value of any supported type
     */
    private fun updateForValue(
        buffer: SdkBuffer,
        stringValue: String?,
        binaryValue: ByteArray?,
        stringListValues: List<String>?,
        binaryListValues: List<ByteArray>?,
    ): Boolean {
        when {
            stringValue != null -> updateForStringType(buffer, stringValue)
            binaryValue != null -> updateForBinaryType(buffer, binaryValue)
            !stringListValues.isNullOrEmpty() -> updateForStringListType(buffer, stringListValues)
            !binaryListValues.isNullOrEmpty() -> updateForBinaryListType(buffer, binaryListValues)
            else -> return false
        }
        return true
    }

    /** Writes the string-type marker followed by the length-prefixed [value]. */
    private fun updateForStringType(buffer: SdkBuffer, value: String) {
        buffer.writeByte(STRING_TYPE_FIELD_INDEX)
        updateLengthAndBytes(buffer, value)
    }

    /** Writes the binary-type marker followed by the length-prefixed [value]. */
    private fun updateForBinaryType(buffer: SdkBuffer, value: ByteArray) {
        buffer.writeByte(BINARY_TYPE_FIELD_INDEX)
        updateLengthAndBytes(buffer, value)
    }

    /** Writes the string-list marker followed by each length-prefixed element of [values]. */
    private fun updateForStringListType(buffer: SdkBuffer, values: List<String>) {
        buffer.writeByte(STRING_LIST_TYPE_FIELD_INDEX)
        values.forEach { value ->
            updateLengthAndBytes(buffer, value)
        }
    }

    /** Writes the binary-list marker followed by each length-prefixed element of [values]. */
    private fun updateForBinaryListType(buffer: SdkBuffer, values: List<ByteArray>) {
        buffer.writeByte(BINARY_LIST_TYPE_FIELD_INDEX)
        values.forEach { value ->
            updateLengthAndBytes(buffer, value)
        }
    }

    /** Writes the UTF-8 encoding of [stringValue], prefixed with its byte length. */
    private fun updateLengthAndBytes(buffer: SdkBuffer, stringValue: String) =
        updateLengthAndBytes(buffer, stringValue.encodeToByteArray())

    /**
     * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
     * input binaryValue and all the bytes it contains.
     */
    private fun updateLengthAndBytes(buffer: SdkBuffer, binaryValue: ByteArray) {
        buffer.writeInt(binaryValue.size)
        buffer.write(binaryValue)
    }
}
|