1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + | package aws.sdk.kotlin.services.sqs
|
6 + |
|
7 + | import aws.sdk.kotlin.services.sqs.internal.ValidationEnabled
|
8 + | import aws.sdk.kotlin.services.sqs.internal.ValidationScope
|
9 + | import aws.sdk.kotlin.services.sqs.model.*
|
10 + | import aws.smithy.kotlin.runtime.ClientException
|
11 + | import aws.smithy.kotlin.runtime.InternalApi
|
12 + | import aws.smithy.kotlin.runtime.client.ResponseInterceptorContext
|
13 + | import aws.smithy.kotlin.runtime.hashing.Md5
|
14 + | import aws.smithy.kotlin.runtime.hashing.md5
|
15 + | import aws.smithy.kotlin.runtime.http.interceptors.ChecksumMismatchException
|
16 + | import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
|
17 + | import aws.smithy.kotlin.runtime.http.request.HttpRequest
|
18 + | import aws.smithy.kotlin.runtime.http.response.HttpResponse
|
19 + | import aws.smithy.kotlin.runtime.telemetry.logging.Logger
|
20 + | import aws.smithy.kotlin.runtime.telemetry.logging.logger
|
21 + | import kotlin.collections.Map
|
22 + | import kotlin.collections.Set
|
23 + | import kotlin.collections.hashMapOf
|
24 + | import kotlin.collections.isNullOrEmpty
|
25 + | import kotlin.collections.set
|
26 + | import kotlin.collections.sorted
|
27 + | import kotlin.collections.sortedBy
|
28 + |
|
/**
 * Interceptor that validates MD5 checksums for SQS message operations.
 *
 * This interceptor performs client-side validation of the MD5 checksums returned by SQS to ensure
 * message integrity during transmission. It validates the following components:
 * - Message body
 * - Message attributes
 * - Message system attributes (SendMessage and SendMessageBatch only)
 *
 * Supported operations:
 * - SendMessage
 * - SendMessageBatch
 * - ReceiveMessage
 *
 * Components that are absent or empty are skipped. When the locally computed digest differs from the
 * digest reported by SQS, a [ChecksumMismatchException] is thrown. If a digest cannot be computed at all,
 * a [ClientException] is thrown.
 *
 * @param validationEnabled Controls when validation occurs (ALWAYS, WHEN_SENDING, WHEN_RECEIVING, NEVER).
 * A `null` value validates every supported operation.
 * @param validationScopes Specifies which message components to validate.
 */
@OptIn(InternalApi::class, ExperimentalStdlibApi::class)
public class SqsMd5ChecksumValidationInterceptor(
    private val validationEnabled: ValidationEnabled?,
    private val validationScopes: Set<ValidationScope>,
) : HttpInterceptor {

    override fun readAfterExecution(context: ResponseInterceptorContext<Any, Any, HttpRequest?, HttpResponse?>) {
        if (validationEnabled == ValidationEnabled.NEVER) return

        // A failed execution has no output to validate.
        val response = context.response.getOrNull() ?: return

        // Resolved per call and passed down explicitly (never cached in shared state) so that
        // concurrent operations cannot log through each other's logger.
        val logger = context.executionContext.coroutineContext.logger<SqsMd5ChecksumValidationInterceptor>()

        val validateSends = validationEnabled != ValidationEnabled.WHEN_RECEIVING
        val validateReceives = validationEnabled != ValidationEnabled.WHEN_SENDING

        when (val request = context.request) {
            is SendMessageRequest -> {
                if (validateSends) {
                    sendMessageOperationMd5Check(request, response as SendMessageResponse, logger)
                }
            }
            is ReceiveMessageRequest -> {
                if (validateReceives) {
                    receiveMessageResultMd5Check(response as ReceiveMessageResponse, logger)
                }
            }
            is SendMessageBatchRequest -> {
                if (validateSends) {
                    sendMessageBatchOperationMd5Check(request, response as SendMessageBatchResponse, logger)
                }
            }
        }
    }

    /**
     * Validates the body, attribute and system-attribute digests that SQS returned for a SendMessage call
     * against digests computed from what was sent.
     */
    private fun sendMessageOperationMd5Check(
        sendMessageRequest: SendMessageRequest,
        sendMessageResponse: SendMessageResponse,
        logger: Logger,
    ) {
        if (ValidationScope.MESSAGE_BODY in validationScopes) {
            val messageBodySent = sendMessageRequest.messageBody
            if (!messageBodySent.isNullOrEmpty()) {
                logger.debug { "Validating message body MD5 checksum for SendMessage" }
                verifyChecksum(calculateMessageBodyMd5(messageBodySent), sendMessageResponse.md5OfMessageBody)
            }
        }

        if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
            val messageAttrSent = sendMessageRequest.messageAttributes
            if (!messageAttrSent.isNullOrEmpty()) {
                logger.debug { "Validating message attribute MD5 checksum for SendMessage" }
                verifyChecksum(
                    calculateMessageAttributesMd5(messageAttrSent),
                    sendMessageResponse.md5OfMessageAttributes,
                )
            }
        }

        if (ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES in validationScopes) {
            val messageSysAttrSent = sendMessageRequest.messageSystemAttributes
            if (!messageSysAttrSent.isNullOrEmpty()) {
                logger.debug { "Validating message system attribute MD5 checksum for SendMessage" }
                verifyChecksum(
                    calculateMessageSystemAttributesMd5(messageSysAttrSent),
                    sendMessageResponse.md5OfMessageSystemAttributes,
                )
            }
        }
    }

    /**
     * Validates the body and attribute digests of every message returned by ReceiveMessage against digests
     * computed from the received content. System attributes of received messages are not validated.
     */
    private fun receiveMessageResultMd5Check(receiveMessageResponse: ReceiveMessageResponse, logger: Logger) {
        for (messageReceived in receiveMessageResponse.messages.orEmpty()) {
            if (ValidationScope.MESSAGE_BODY in validationScopes) {
                val messageBody = messageReceived.body
                if (!messageBody.isNullOrEmpty()) {
                    logger.debug { "Validating message body MD5 checksum for ReceiveMessage" }
                    verifyChecksum(calculateMessageBodyMd5(messageBody), messageReceived.md5OfBody)
                }
            }

            if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
                val messageAttr = messageReceived.messageAttributes
                if (!messageAttr.isNullOrEmpty()) {
                    logger.debug { "Validating message attribute MD5 checksum for ReceiveMessage" }
                    verifyChecksum(calculateMessageAttributesMd5(messageAttr), messageReceived.md5OfMessageAttributes)
                }
            }
        }
    }

    /**
     * Validates every successful SendMessageBatch result entry against the request entry with the same id.
     * Result entries whose id has no matching request entry are skipped.
     */
    private fun sendMessageBatchOperationMd5Check(
        sendMessageBatchRequest: SendMessageBatchRequest,
        sendMessageBatchResponse: SendMessageBatchResponse,
        logger: Logger,
    ) {
        val idToRequestEntryMap = hashMapOf<String, SendMessageBatchRequestEntry>()
        sendMessageBatchRequest.entries?.forEach { requestEntry -> idToRequestEntryMap[requestEntry.id] = requestEntry }

        for (entry in sendMessageBatchResponse.successful) {
            val requestEntry = idToRequestEntryMap[entry.id]

            if (ValidationScope.MESSAGE_BODY in validationScopes) {
                val messageBody = requestEntry?.messageBody
                if (!messageBody.isNullOrEmpty()) {
                    logger.debug { "Validating message body MD5 checksum for SendMessageBatch: ${entry.messageId}" }
                    verifyChecksum(calculateMessageBodyMd5(messageBody), entry.md5OfMessageBody)
                }
            }

            if (ValidationScope.MESSAGE_ATTRIBUTES in validationScopes) {
                val messageAttrSent = requestEntry?.messageAttributes
                if (!messageAttrSent.isNullOrEmpty()) {
                    logger.debug { "Validating message attribute MD5 checksum for SendMessageBatch: ${entry.messageId}" }
                    verifyChecksum(calculateMessageAttributesMd5(messageAttrSent), entry.md5OfMessageAttributes)
                }
            }

            if (ValidationScope.MESSAGE_SYSTEM_ATTRIBUTES in validationScopes) {
                val messageSysAttrSent = requestEntry?.messageSystemAttributes
                if (!messageSysAttrSent.isNullOrEmpty()) {
                    logger.debug {
                        "Validating message system attribute MD5 checksum for SendMessageBatch: ${entry.messageId}"
                    }
                    verifyChecksum(
                        calculateMessageSystemAttributesMd5(messageSysAttrSent),
                        entry.md5OfMessageSystemAttributes,
                    )
                }
            }
        }
    }

    /**
     * Throws [ChecksumMismatchException] when the digest reported by SQS ([actual], possibly absent)
     * differs from the locally computed [expected] digest.
     */
    private fun verifyChecksum(expected: String, actual: String?) {
        if (expected != actual) {
            throw ChecksumMismatchException("Checksum mismatch. Expected $expected but was $actual")
        }
    }

    /** Returns the hex-encoded MD5 of the UTF-8 bytes of [messageBody]. */
    private fun calculateMessageBodyMd5(messageBody: String): String =
        md5Hex("message body") { messageBody.encodeToByteArray().md5() }

    /**
     * Calculates the MD5 digest for message attributes according to SQS specifications.
     * https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html#sqs-attributes-md5-message-digest-calculation
     *
     * Attributes are hashed in ascending order of their names.
     */
    private fun calculateMessageAttributesMd5(messageAttributes: Map<String, MessageAttributeValue>): String =
        md5Hex("message attributes") {
            val md5Digest = Md5()
            for (attributeName in messageAttributes.keys.sorted()) {
                val attributeValue = messageAttributes.getValue(attributeName)
                updateForAttribute(
                    md5Digest,
                    attributeName,
                    attributeValue.dataType,
                    attributeValue.stringValue,
                    attributeValue.binaryValue,
                    attributeValue.stringListValues,
                    attributeValue.binaryListValues,
                )
            }
            md5Digest.digest()
        }

    /**
     * Calculates the MD5 digest for message system attributes using the same encoding as message attributes.
     * Attributes are hashed in ascending order of their wire names.
     */
    private fun calculateMessageSystemAttributesMd5(
        messageSysAttrs: Map<MessageSystemAttributeNameForSends, MessageSystemAttributeValue>,
    ): String = md5Hex("message system attributes") {
        val md5Digest = Md5()
        for (attributeName in messageSysAttrs.keys.sortedBy { it.value }) {
            val attributeValue = messageSysAttrs.getValue(attributeName)
            updateForAttribute(
                md5Digest,
                attributeName.value,
                attributeValue.dataType,
                attributeValue.stringValue,
                attributeValue.binaryValue,
                attributeValue.stringListValues,
                attributeValue.binaryListValues,
            )
        }
        md5Digest.digest()
    }

    /**
     * Runs [computeDigest] and returns its result hex-encoded.
     *
     * @param subject Human-readable name of what is being hashed, used only in the error message.
     * @throws ClientException if the digest cannot be computed; the original failure is attached as the cause.
     */
    private fun md5Hex(subject: String, computeDigest: () -> ByteArray): String {
        val digest = try {
            computeDigest()
        } catch (e: Exception) {
            throw ClientException(md5FailureMessage(subject, e), e)
        }
        return digest.toHexString()
    }

    /** Builds the explanatory message used when an MD5 digest cannot be computed. */
    private fun md5FailureMessage(subject: String, cause: Exception): String = buildString {
        append("Unable to calculate the MD5 hash of the $subject. ")
        append("Potential reasons include JVM configuration or FIPS compliance issues. ")
        append("To disable message MD5 validation, you can set checksumValidationEnabled ")
        append("to NEVER when instantiating the client.")
        cause.message?.let { append(' ').append(it) }
    }

    /**
     * Feeds a single attribute into [md5Digest]. The attribute is written as:
     * - the length-prefixed name;
     * - the length-prefixed data type, when present;
     * - a one-byte transport-type marker followed by the length-prefixed value(s), using the first
     *   non-null / non-empty value kind.
     */
    private fun updateForAttribute(
        md5Digest: Md5,
        name: String,
        dataType: String?,
        stringValue: String?,
        binaryValue: ByteArray?,
        stringListValues: List<String>?,
        binaryListValues: List<ByteArray>?,
    ) {
        updateLengthAndBytes(md5Digest, name.encodeToByteArray())
        dataType?.let { updateLengthAndBytes(md5Digest, it.encodeToByteArray()) }

        when {
            stringValue != null -> updateForStringType(md5Digest, stringValue)
            binaryValue != null -> updateForBinaryType(md5Digest, binaryValue)
            !stringListValues.isNullOrEmpty() -> updateForStringListType(md5Digest, stringListValues)
            !binaryListValues.isNullOrEmpty() -> updateForBinaryListType(md5Digest, binaryListValues)
        }
    }

    private fun updateForStringType(md5Digest: Md5, value: String) {
        md5Digest.update(STRING_TYPE_FIELD_INDEX)
        updateLengthAndBytes(md5Digest, value.encodeToByteArray())
    }

    private fun updateForBinaryType(md5Digest: Md5, value: ByteArray) {
        md5Digest.update(BINARY_TYPE_FIELD_INDEX)
        updateLengthAndBytes(md5Digest, value)
    }

    private fun updateForStringListType(md5Digest: Md5, values: List<String>) {
        md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX)
        values.forEach { value -> updateLengthAndBytes(md5Digest, value.encodeToByteArray()) }
    }

    private fun updateForBinaryListType(md5Digest: Md5, values: List<ByteArray>) {
        md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX)
        values.forEach { value -> updateLengthAndBytes(md5Digest, value) }
    }

    /**
     * Updates [messageDigest] with the length of [binaryValue] as 4 big-endian bytes,
     * followed by all bytes of [binaryValue].
     */
    private fun updateLengthAndBytes(messageDigest: Md5, binaryValue: ByteArray) {
        val length = binaryValue.size
        val lengthBytes = byteArrayOf(
            (length shr 24).toByte(),
            (length shr 16).toByte(),
            (length shr 8).toByte(),
            length.toByte(),
        )
        messageDigest.update(lengthBytes)
        messageDigest.update(binaryValue)
    }

    public companion object {
        // One-byte transport-type markers written before an attribute's value(s) in the attribute digest.
        private const val STRING_TYPE_FIELD_INDEX: Byte = 1
        private const val BINARY_TYPE_FIELD_INDEX: Byte = 2
        private const val STRING_LIST_TYPE_FIELD_INDEX: Byte = 3
        private const val BINARY_LIST_TYPE_FIELD_INDEX: Byte = 4
    }
}
|