9 9 | import aws.sdk.kotlin.services.s3.S3Client
|
10 10 | import aws.sdk.kotlin.services.s3.model.CreateSessionRequest
|
11 11 | import aws.sdk.kotlin.services.s3.model.CreateSessionResponse
|
12 12 | import aws.sdk.kotlin.services.s3.model.SessionCredentials
|
13 13 | import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
|
14 14 | import aws.smithy.kotlin.runtime.io.use
|
15 15 | import aws.smithy.kotlin.runtime.operation.ExecutionContext
|
16 16 | import aws.smithy.kotlin.runtime.time.ManualClock
|
17 17 | import kotlinx.coroutines.*
|
18 18 | import kotlinx.coroutines.test.runTest
|
19 - | import kotlin.test.*
|
19 + | import kotlin.test.Test
|
20 + | import kotlin.test.assertEquals
|
21 + | import kotlin.test.assertFalse
|
20 22 | import kotlin.time.ComparableTimeMark
|
21 23 | import kotlin.time.Duration.Companion.milliseconds
|
22 24 | import kotlin.time.Duration.Companion.minutes
|
23 25 | import kotlin.time.Duration.Companion.seconds
|
24 26 | import kotlin.time.TestTimeSource
|
25 27 |
|
/**
 * Base (bootstrap) credentials shared by these tests: [TestS3Client] configures its wrapped client with them by
 * default, and they are used to build cache keys passed to the provider.
 */
private val DEFAULT_BASE_CREDENTIALS = Credentials(
    accessKeyId = "accessKeyId",
    secretAccessKey = "secretAccessKey",
    sessionToken = "sessionToken",
)
|
27 29 |
|
/**
 * Tests for [DefaultS3ExpressCredentialsProvider] covering:
 * - direct creation of session credentials via s3:CreateSession;
 * - synchronous refresh of an already-expired cache entry;
 * - background ("async") refresh of an entry that falls inside the refresh buffer, including debouncing of
 *   concurrent callers, tolerance of CreateSession failures, and closing the provider while a refresh is in flight.
 *
 * Each [TestS3Client] wraps a real [S3Client], so clients are always closed via `use`. Where possible the
 * provider is also closed via `use`, so it is released even when an assertion fails.
 */
class DefaultS3ExpressCredentialsProviderTest {
    @Test
    fun testCreateSessionCredentials() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        TestS3Client(expectedSessionCredentials(clock)).use { client ->
            DefaultS3ExpressCredentialsProvider(timeSource, clock).use { provider ->
                val credentials = provider.createSessionCredentials(
                    S3ExpressCredentialsCacheKey("bucket", DEFAULT_BASE_CREDENTIALS),
                    client,
                )
                assertFalse(credentials.isExpired)
                assertEquals(timeSource.markNow() + 5.minutes, credentials.expiresAt)
            }
        }
    }

    @Test
    fun testSyncRefresh() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        // Entry expired 30 seconds ago, next `resolve` call should trigger a sync refresh
        val cache = S3ExpressCredentialsCache()
        val entry = getCacheEntry(timeSource.markNow() - 30.seconds)
        cache.put(entry.key, entry.value)

        TestS3Client(expectedSessionCredentials(clock)).use { testClient ->
            DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes).use { provider ->
                provider.resolve(expressAttributes(testClient))
            }
            assertEquals(1, testClient.numCreateSession)
        }
    }

    @Test
    fun testAsyncRefresh() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
        val cache = S3ExpressCredentialsCache()
        val entry = getCacheEntry(timeSource.markNow() + 30.seconds)
        cache.put(entry.key, entry.value)

        TestS3Client(expectedSessionCredentials(clock)).use { testClient ->
            val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
            provider.use {
                provider.resolve(expressAttributes(testClient))

                // Allow the async refresh to initiate before closing the provider. This is a real-time wait:
                // `delay` directly inside `runTest` would be skipped by virtual time, and `runBlocking` would
                // block the test thread.
                withContext(Dispatchers.Default) { delay(500.milliseconds) }
            }
            // `use` has closed the provider; wait for its job to finish before checking the call count.
            provider.coroutineContext.job.join()

            assertEquals(1, testClient.numCreateSession)
        }
    }

    @Test
    fun testAsyncRefreshDebounce() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
        val cache = S3ExpressCredentialsCache()
        val entry = getCacheEntry(expiration = timeSource.markNow() + 30.seconds)
        cache.put(entry.key, entry.value)

        TestS3Client(expectedSessionCredentials(clock)).use { testClient ->
            val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
            provider.use {
                val attributes = expressAttributes(testClient)

                // Five concurrent resolves of the same near-expiry entry should trigger only one refresh.
                List(5) { async { provider.resolve(attributes) } }.awaitAll()

                // Allow the async refresh to initiate before closing the provider (real time, see testAsyncRefresh).
                withContext(Dispatchers.Default) { delay(500.milliseconds) }
            }
            provider.coroutineContext.job.join()

            assertEquals(1, testClient.numCreateSession)
        }
    }

    @Test
    fun testAsyncRefreshHandlesFailures() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
        val cache = S3ExpressCredentialsCache()
        val successEntry = getCacheEntry(timeSource.markNow() + 30.seconds, bucket = "SuccessfulBucket")
        val failedEntry = getCacheEntry(timeSource.markNow() + 30.seconds, bucket = "ExceptionBucket")
        cache.put(successEntry.key, successEntry.value)
        cache.put(failedEntry.key, failedEntry.value)

        // client should throw an exception when `ExceptionBucket` credentials are fetched, but it should be caught
        TestS3Client(expectedSessionCredentials(clock), throwExceptionOnBucketNamed = "ExceptionBucket").use { testClient ->
            val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
            provider.use {
                val attributes = expressAttributes(testClient, bucket = "ExceptionBucket")
                provider.resolve(attributes)

                awaitCreateSessionCount(testClient, 1)
                assertEquals(1, testClient.numCreateSession)

                // A failed refresh for one bucket must not prevent refreshes for other buckets.
                attributes[S3Attributes.Bucket] = "SuccessfulBucket"
                provider.resolve(attributes)

                awaitCreateSessionCount(testClient, 2)
            }
            provider.coroutineContext.job.join()

            assertEquals(2, testClient.numCreateSession)
        }
    }

    @Test
    fun testAsyncRefreshClosesImmediately() = runTest {
        val timeSource = TestTimeSource()
        val clock = ManualClock()

        // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
        val cache = S3ExpressCredentialsCache()
        val entry = getCacheEntry(timeSource.markNow() + 30.seconds)
        cache.put(entry.key, entry.value)

        val expectedCredentials = expectedSessionCredentials(clock)

        val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)

        // CreateSession delays for 10 seconds before counting the call. If the provider closes within the
        // 5-second window below, the count is expected to stay at 0.
        class BlockingTestS3Client : TestS3Client(expectedCredentials) {
            override suspend fun createSession(input: CreateSessionRequest): CreateSessionResponse {
                delay(10.seconds)
                numCreateSession += 1
                return CreateSessionResponse { credentials = expectedCredentials }
            }
        }

        BlockingTestS3Client().use { blockingTestS3Client ->
            val attributes = expressAttributes(blockingTestS3Client)

            withTimeout(5.seconds) {
                provider.resolve(attributes)
                provider.close()
            }
            assertEquals(0, blockingTestS3Client.numCreateSession)
        }
    }

    /**
     * Session credentials returned by the mocked s3:CreateSession. They expire 5 minutes after [clock]'s
     * current time.
     */
    private fun expectedSessionCredentials(clock: ManualClock) = SessionCredentials {
        accessKeyId = "access"
        secretAccessKey = "secret"
        sessionToken = "session"
        expiration = clock.now() + 5.minutes
    }

    /**
     * Build the [ExecutionContext] passed to `provider.resolve`. It sets [client] as the
     * [S3Attributes.ExpressClient] and [bucket] as the [S3Attributes.Bucket].
     */
    private fun expressAttributes(client: S3Client, bucket: String = "bucket") = ExecutionContext.build {
        this.attributes[S3Attributes.ExpressClient] = client
        this.attributes[S3Attributes.Bucket] = bucket
    }

    /**
     * Suspend until [client] has recorded at least [expected] s3:CreateSession calls, failing after 5 seconds.
     *
     * Inside `runTest`, `withTimeout` measures virtual time, and virtual time does not advance while this loop
     * keeps yielding, so the timeout would never fire. Running on [Dispatchers.Default] makes it a real-time bound.
     */
    private suspend fun awaitCreateSessionCount(client: TestS3Client, expected: Int) {
        withContext(Dispatchers.Default) {
            withTimeout(5.seconds) {
                while (client.numCreateSession < expected) {
                    yield()
                }
            }
        }
    }

    /**
     * Get an instance of [Map.Entry<S3ExpressCredentialsCacheKey, S3ExpressCredentialsCacheValue>] using the given [expiration],
     * [bucket], and optional [bootstrapCredentials] and [sessionCredentials].
     *
     * [bootstrapCredentials] defaults to [DEFAULT_BASE_CREDENTIALS], the same credentials [TestS3Client] configures
     * its wrapped client with.
     */
    private fun getCacheEntry(
        expiration: ComparableTimeMark,
        bucket: String = "bucket",
        bootstrapCredentials: Credentials = DEFAULT_BASE_CREDENTIALS,
        sessionCredentials: Credentials = Credentials(accessKeyId = "s3AccessKeyId", secretAccessKey = "s3SecretAccessKey", sessionToken = "s3SessionToken"),
    ): S3ExpressCredentialsCacheEntry = mapOf(
        S3ExpressCredentialsCacheKey(bucket, bootstrapCredentials) to S3ExpressCredentialsCacheValue(ExpiringValue(sessionCredentials, expiration)),
    ).entries.first()

    /**
     * A test S3Client used to mock calls to s3:CreateSession.
     * @param expectedCredentials the expected session credentials returned from s3:CreateSession
     * @param baseCredentials the static credentials the default [client] is configured with
     * @param client the base S3 client used to implement other operations, though they are unused.
     * @param throwExceptionOnBucketNamed an optional bucket name, which when specified and present in the [CreateSessionRequest], will
     * cause the client to throw an exception instead of returning credentials. Used for testing s3:CreateSession failures.
     */
    private open inner class TestS3Client(
        val expectedCredentials: SessionCredentials,
        val baseCredentials: Credentials = DEFAULT_BASE_CREDENTIALS,
        val client: S3Client = S3Client { credentialsProvider = StaticCredentialsProvider(baseCredentials) },
        val throwExceptionOnBucketNamed: String? = null,
    ) : S3Client by client {
        /**
         * Number of s3:CreateSession calls received, including calls that threw.
         *
         * NOTE(review): this is a plain `var`. It may be incremented from the provider's refresh coroutine and read
         * from the test coroutine. Consider an atomic counter if these tests become flaky.
         */
        var numCreateSession = 0

        override suspend fun createSession(input: CreateSessionRequest): CreateSessionResponse {
            numCreateSession += 1
            throwExceptionOnBucketNamed?.let {
                if (input.bucket == it) {
                    throw Exception("Failed to create session credentials for bucket: $throwExceptionOnBucketNamed")
                }
            }
            return CreateSessionResponse { credentials = expectedCredentials }
        }

        /** Close the wrapped [client], which was created by this test client by default. */
        override fun close() {
            client.close()
        }
    }
}
|