|     9      9    | import aws.sdk.kotlin.services.s3.S3Client
 | 
      |    10     10    | import aws.sdk.kotlin.services.s3.model.CreateSessionRequest
 | 
      |    11     11    | import aws.sdk.kotlin.services.s3.model.CreateSessionResponse
 | 
      |    12     12    | import aws.sdk.kotlin.services.s3.model.SessionCredentials
 | 
      |    13     13    | import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials
 | 
      |    14     14    | import aws.smithy.kotlin.runtime.io.use
 | 
      |    15     15    | import aws.smithy.kotlin.runtime.operation.ExecutionContext
 | 
      |    16     16    | import aws.smithy.kotlin.runtime.time.ManualClock
 | 
      |    17     17    | import kotlinx.coroutines.*
 | 
      |    18     18    | import kotlinx.coroutines.test.runTest
 | 
      |    19         - | import kotlin.test.*
 | 
      |           19  + | import kotlin.test.Test
 | 
      |           20  + | import kotlin.test.assertEquals
 | 
      |           21  + | import kotlin.test.assertFalse
 | 
      |    20     22    | import kotlin.time.ComparableTimeMark
 | 
      |    21     23    | import kotlin.time.Duration.Companion.milliseconds
 | 
      |    22     24    | import kotlin.time.Duration.Companion.minutes
 | 
      |    23     25    | import kotlin.time.Duration.Companion.seconds
 | 
      |    24     26    | import kotlin.time.TestTimeSource
 | 
      |    25     27    | 
 | 
      |    26     28    | private val DEFAULT_BASE_CREDENTIALS = Credentials("accessKeyId", "secretAccessKey", "sessionToken")
 | 
      |    27     29    | 
 | 
      |    28     30    | class DefaultS3ExpressCredentialsProviderTest {
 | 
      |    29     31    |     @Test
 | 
      |    30     32    |     fun testCreateSessionCredentials() = runTest {
 | 
      |    31     33    |         val timeSource = TestTimeSource()
 | 
      |    32     34    |         val clock = ManualClock()
 | 
      |    33     35    | 
 | 
      |    34     36    |         val expectedCredentials = SessionCredentials {
 | 
      |    35     37    |             accessKeyId = "access"
 | 
      |    36     38    |             secretAccessKey = "secret"
 | 
      |    37     39    |             sessionToken = "session"
 | 
      |    38     40    |             expiration = clock.now() + 5.minutes
 | 
      |    39     41    |         }
 | 
      |    40     42    | 
 | 
      |    41         - |         val client = TestS3Client(expectedCredentials)
 | 
      |    42         - | 
 | 
      |           43  + |         TestS3Client(expectedCredentials).use { client ->
 | 
      |    43     44    |             DefaultS3ExpressCredentialsProvider(timeSource, clock).use { provider ->
 | 
      |    44     45    |                 val credentials = provider.createSessionCredentials(
 | 
      |    45     46    |                     S3ExpressCredentialsCacheKey("bucket", DEFAULT_BASE_CREDENTIALS),
 | 
      |    46     47    |                     client,
 | 
      |    47     48    |                 )
 | 
      |    48     49    |                 assertFalse(credentials.isExpired)
 | 
      |    49     50    |                 assertEquals(timeSource.markNow() + 5.minutes, credentials.expiresAt)
 | 
      |    50     51    |             }
 | 
      |    51     52    |         }
 | 
      |           53  + |     }
 | 
      |    52     54    | 
 | 
      |    53     55    |     @Test
 | 
      |    54     56    |     fun testSyncRefresh() = runTest {
 | 
      |    55     57    |         val timeSource = TestTimeSource()
 | 
      |    56     58    |         val clock = ManualClock()
 | 
      |    57     59    | 
 | 
      |    58     60    |         // Entry expired 30 seconds ago, next `resolve` call should trigger a sync refresh
 | 
      |    59     61    |         val cache = S3ExpressCredentialsCache()
 | 
      |    60     62    |         val entry = getCacheEntry(timeSource.markNow() - 30.seconds)
 | 
      |    61     63    |         cache.put(entry.key, entry.value)
 | 
      |    62     64    | 
 | 
      |    63     65    |         val expectedCredentials = SessionCredentials {
 | 
      |    64     66    |             accessKeyId = "access"
 | 
      |    65     67    |             secretAccessKey = "secret"
 | 
      |    66     68    |             sessionToken = "session"
 | 
      |    67     69    |             expiration = clock.now() + 5.minutes
 | 
      |    68     70    |         }
 | 
      |    69     71    | 
 | 
      |    70         - |         val testClient = TestS3Client(expectedCredentials)
 | 
      |           72  + |         TestS3Client(expectedCredentials).use { testClient ->
 | 
      |    71     73    |             DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes).use { provider ->
 | 
      |    72     74    |                 val attributes = ExecutionContext.build {
 | 
      |    73     75    |                     this.attributes[S3Attributes.ExpressClient] = testClient
 | 
      |    74     76    |                     this.attributes[S3Attributes.Bucket] = "bucket"
 | 
      |    75     77    |                 }
 | 
      |    76     78    | 
 | 
      |    77     79    |                 provider.resolve(attributes)
 | 
      |    78     80    |             }
 | 
      |    79     81    |             assertEquals(1, testClient.numCreateSession)
 | 
      |    80     82    |         }
 | 
      |           83  + |     }
 | 
      |    81     84    | 
 | 
      |    82     85    |     @Test
 | 
      |    83     86    |     fun testAsyncRefresh() = runTest {
 | 
      |    84     87    |         val timeSource = TestTimeSource()
 | 
      |    85     88    |         val clock = ManualClock()
 | 
      |    86     89    | 
 | 
      |    87     90    |         // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
 | 
      |    88     91    |         val cache = S3ExpressCredentialsCache()
 | 
      |    89     92    |         val entry = getCacheEntry(timeSource.markNow() + 30.seconds)
 | 
      |    90     93    |         cache.put(entry.key, entry.value)
 | 
      |    91     94    | 
 | 
      |    92     95    |         val expectedCredentials = SessionCredentials {
 | 
      |    93     96    |             accessKeyId = "access"
 | 
      |    94     97    |             secretAccessKey = "secret"
 | 
      |    95     98    |             sessionToken = "session"
 | 
      |    96     99    |             expiration = clock.now() + 5.minutes
 | 
      |    97    100    |         }
 | 
      |    98    101    | 
 | 
      |    99         - |         val testClient = TestS3Client(expectedCredentials)
 | 
      |   100         - | 
 | 
      |          102  + |         TestS3Client(expectedCredentials).use { testClient ->
 | 
      |   101    103    |             val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
 | 
      |   102    104    | 
 | 
      |   103    105    |             val attributes = ExecutionContext.build {
 | 
      |   104    106    |                 this.attributes[S3Attributes.ExpressClient] = testClient
 | 
      |   105    107    |                 this.attributes[S3Attributes.Bucket] = "bucket"
 | 
      |   106    108    |             }
 | 
      |   107    109    |             provider.resolve(attributes)
 | 
      |   108    110    | 
 | 
      |   109    111    |             // allow the async refresh to initiate before closing the provider
 | 
      |   110    112    |             runBlocking { delay(500.milliseconds) }
 | 
      |   111    113    | 
 | 
      |   112    114    |             provider.close()
 | 
      |   113    115    |             provider.coroutineContext.job.join()
 | 
      |   114    116    | 
 | 
      |   115    117    |             assertEquals(1, testClient.numCreateSession)
 | 
      |   116    118    |         }
 | 
      |          119  + |     }
 | 
      |   117    120    | 
 | 
      |   118    121    |     @Test
 | 
      |   119    122    |     fun testAsyncRefreshDebounce() = runTest {
 | 
      |   120    123    |         val timeSource = TestTimeSource()
 | 
      |   121    124    |         val clock = ManualClock()
 | 
      |   122    125    | 
 | 
      |   123    126    |         // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
 | 
      |   124    127    |         val cache = S3ExpressCredentialsCache()
 | 
      |   125    128    |         val entry = getCacheEntry(expiration = timeSource.markNow() + 30.seconds)
 | 
      |   126    129    |         cache.put(entry.key, entry.value)
 | 
      |   127    130    | 
 | 
      |   128    131    |         val expectedCredentials = SessionCredentials {
 | 
      |   129    132    |             accessKeyId = "access"
 | 
      |   130    133    |             secretAccessKey = "secret"
 | 
      |   131    134    |             sessionToken = "session"
 | 
      |   132    135    |             expiration = clock.now() + 5.minutes
 | 
      |   133    136    |         }
 | 
      |   134    137    | 
 | 
      |   135         - |         val testClient = TestS3Client(expectedCredentials)
 | 
      |   136         - | 
 | 
      |          138  + |         TestS3Client(expectedCredentials).use { testClient ->
 | 
      |   137    139    |             val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
 | 
      |   138    140    | 
 | 
      |   139    141    |             val attributes = ExecutionContext.build {
 | 
      |   140    142    |                 this.attributes[S3Attributes.ExpressClient] = testClient
 | 
      |   141    143    |                 this.attributes[S3Attributes.Bucket] = "bucket"
 | 
      |   142    144    |             }
 | 
      |   143    145    |             val calls = (1..5).map {
 | 
      |   144    146    |                 async { provider.resolve(attributes) }
 | 
      |   145    147    |             }
 | 
      |   146    148    |             calls.awaitAll()
 | 
      |   147    149    | 
 | 
      |   148    150    |             // allow the async refresh to initiate before closing the provider
 | 
      |   149    151    |             runBlocking { delay(500.milliseconds) }
 | 
      |   150    152    | 
 | 
      |   151    153    |             provider.close()
 | 
      |   152    154    |             provider.coroutineContext.job.join()
 | 
      |   153    155    | 
 | 
      |   154    156    |             assertEquals(1, testClient.numCreateSession)
 | 
      |   155    157    |         }
 | 
      |          158  + |     }
 | 
      |   156    159    | 
 | 
      |   157    160    |     @Test
 | 
      |   158    161    |     fun testAsyncRefreshHandlesFailures() = runTest {
 | 
      |   159    162    |         val timeSource = TestTimeSource()
 | 
      |   160    163    |         val clock = ManualClock()
 | 
      |   161    164    | 
 | 
      |   162    165    |         // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
 | 
      |   163    166    |         val cache = S3ExpressCredentialsCache()
 | 
      |   164    167    |         val successEntry = getCacheEntry(timeSource.markNow() + 30.seconds, bucket = "SuccessfulBucket")
 | 
      |   165    168    |         val failedEntry = getCacheEntry(timeSource.markNow() + 30.seconds, bucket = "ExceptionBucket")
 | 
      |   166    169    |         cache.put(successEntry.key, successEntry.value)
 | 
      |   167    170    |         cache.put(failedEntry.key, failedEntry.value)
 | 
      |   168    171    | 
 | 
      |   169    172    |         val expectedCredentials = SessionCredentials {
 | 
      |   170    173    |             accessKeyId = "access"
 | 
      |   171    174    |             secretAccessKey = "secret"
 | 
      |   172    175    |             sessionToken = "session"
 | 
      |   173    176    |             expiration = clock.now() + 5.minutes
 | 
      |   174    177    |         }
 | 
      |   175    178    | 
 | 
      |   176    179    |         // client should throw an exception when `ExceptionBucket` credentials are fetched, but it should be caught
 | 
      |   177         - |         val testClient = TestS3Client(expectedCredentials, throwExceptionOnBucketNamed = "ExceptionBucket")
 | 
      |   178         - | 
 | 
      |          180  + |         TestS3Client(expectedCredentials, throwExceptionOnBucketNamed = "ExceptionBucket").use { testClient ->
 | 
      |   179    181    |             val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
 | 
      |   180    182    |             val attributes = ExecutionContext.build {
 | 
      |   181    183    |                 this.attributes[S3Attributes.ExpressClient] = testClient
 | 
      |   182    184    |                 this.attributes[S3Attributes.Bucket] = "ExceptionBucket"
 | 
      |   183    185    |             }
 | 
      |   184    186    |             provider.resolve(attributes)
 | 
      |   185    187    | 
 | 
      |   186    188    |             withTimeout(5.seconds) {
 | 
      |   187    189    |                 while (testClient.numCreateSession != 1) {
 | 
      |   188    190    |                     yield()
 | 
      |   189    191    |                 }
 | 
      |   190    192    |             }
 | 
      |   191    193    |             assertEquals(1, testClient.numCreateSession)
 | 
      |   192    194    | 
 | 
      |   193    195    |             attributes[S3Attributes.Bucket] = "SuccessfulBucket"
 | 
      |   194    196    |             provider.resolve(attributes)
 | 
      |   195    197    | 
 | 
      |   196    198    |             withTimeout(5.seconds) {
 | 
      |   197    199    |                 while (testClient.numCreateSession != 2) {
 | 
      |   198    200    |                     yield()
 | 
      |   199    201    |                 }
 | 
      |   200    202    |             }
 | 
      |   201    203    | 
 | 
      |   202    204    |             provider.close()
 | 
      |   203    205    |             provider.coroutineContext.job.join()
 | 
      |   204    206    | 
 | 
      |   205    207    |             assertEquals(2, testClient.numCreateSession)
 | 
      |   206    208    |         }
 | 
      |          209  + |     }
 | 
      |   207    210    | 
 | 
      |   208    211    |     @Test
 | 
      |   209    212    |     fun testAsyncRefreshClosesImmediately() = runTest {
 | 
      |   210    213    |         val timeSource = TestTimeSource()
 | 
      |   211    214    |         val clock = ManualClock()
 | 
      |   212    215    | 
 | 
      |   213    216    |         // Entry expires in 30 seconds, refresh buffer is 1 minute. Next `resolve` call should trigger the async refresh
 | 
      |   214    217    |         val cache = S3ExpressCredentialsCache()
 | 
      |   215    218    |         val entry = getCacheEntry(timeSource.markNow() + 30.seconds)
 | 
      |   216    219    |         cache.put(entry.key, entry.value)
 | 
      |   217    220    | 
 | 
      |   218    221    |         val expectedCredentials = SessionCredentials {
 | 
      |   219    222    |             accessKeyId = "access"
 | 
      |   220    223    |             secretAccessKey = "secret"
 | 
      |   221    224    |             sessionToken = "session"
 | 
      |   222    225    |             expiration = clock.now() + 5.minutes
 | 
      |   223    226    |         }
 | 
      |   224    227    | 
 | 
      |   225    228    |         val provider = DefaultS3ExpressCredentialsProvider(timeSource, clock, cache, refreshBuffer = 1.minutes)
 | 
      |   226    229    | 
 | 
      |   227         - |         val blockingTestS3Client = object : TestS3Client(expectedCredentials) {
 | 
      |          230  + |         class BlockingTestS3Client : TestS3Client(expectedCredentials) {
 | 
      |   228    231    |             override suspend fun createSession(input: CreateSessionRequest): CreateSessionResponse {
 | 
      |   229    232    |                 delay(10.seconds)
 | 
      |   230    233    |                 numCreateSession += 1
 | 
      |   231    234    |                 return CreateSessionResponse { credentials = expectedCredentials }
 | 
      |   232    235    |             }
 | 
      |   233    236    |         }
 | 
      |   234    237    | 
 | 
      |          238  + |         BlockingTestS3Client().use { blockingTestS3Client ->
 | 
      |   235    239    |             val attributes = ExecutionContext.build {
 | 
      |   236    240    |                 this.attributes[S3Attributes.ExpressClient] = blockingTestS3Client
 | 
      |   237    241    |                 this.attributes[S3Attributes.Bucket] = "bucket"
 | 
      |   238    242    |             }
 | 
      |   239    243    | 
 | 
      |   240    244    |             withTimeout(5.seconds) {
 | 
      |   241    245    |                 provider.resolve(attributes)
 | 
      |   242    246    |                 provider.close()
 | 
      |   243    247    |             }
 | 
      |   244    248    |             assertEquals(0, blockingTestS3Client.numCreateSession)
 | 
      |   245    249    |         }
 | 
      |          250  + |     }
 | 
      |   246    251    | 
 | 
      |   247    252    |     /**
 | 
      |   248    253    |      * Get an instance of [Map.Entry<S3ExpressCredentialsCacheKey, S3ExpressCredentialsCacheValue>] using the given [expiration],
 | 
      |   249    254    |      * [bucket], and optional [bootstrapCredentials] and [sessionCredentials].
 | 
      |   250    255    |      */
 | 
      |   251    256    |     private fun getCacheEntry(
 | 
      |   252    257    |         expiration: ComparableTimeMark,
 | 
      |   253    258    |         bucket: String = "bucket",
 | 
      |   254    259    |         bootstrapCredentials: Credentials = Credentials(accessKeyId = "accessKeyId", secretAccessKey = "secretAccessKey", sessionToken = "sessionToken"),
 | 
      |   255    260    |         sessionCredentials: Credentials = Credentials(accessKeyId = "s3AccessKeyId", secretAccessKey = "s3SecretAccessKey", sessionToken = "s3SessionToken"),
 | 
      |   256    261    |     ): S3ExpressCredentialsCacheEntry = mapOf(
 | 
      |   257    262    |         S3ExpressCredentialsCacheKey(bucket, bootstrapCredentials) to S3ExpressCredentialsCacheValue(ExpiringValue(sessionCredentials, expiration)),
 | 
      |   258    263    |     ).entries.first()
 | 
      |   259    264    | 
 | 
      |   260    265    |     /**
 | 
      |   261    266    |      * A test S3Client used to mock calls to s3:CreateSession.
 | 
      |   262    267    |      * @param expectedCredentials the expected session credentials returned from s3:CreateSession
 | 
      |   263    268    |      * @param client the base S3 client used to implement other operations, though they are unused.
 | 
      |   264    269    |      * @param throwExceptionOnBucketNamed an optional bucket name, which when specified and present in the [CreateSessionRequest], will
 | 
      |   265    270    |      * cause the client to throw an exception instead of returning credentials. Used for testing s3:CreateSession failures.
 | 
      |   266    271    |      */
 | 
      |   267    272    |     private open inner class TestS3Client(
 | 
      |   268    273    |         val expectedCredentials: SessionCredentials,
 | 
      |   269    274    |         val baseCredentials: Credentials = DEFAULT_BASE_CREDENTIALS,
 | 
      |   270    275    |         val client: S3Client = S3Client { credentialsProvider = StaticCredentialsProvider(baseCredentials) },
 | 
      |   271    276    |         val throwExceptionOnBucketNamed: String? = null,
 | 
      |   272    277    |     ) : S3Client by client {
 | 
      |   273    278    |         var numCreateSession = 0
 | 
      |   274    279    | 
 | 
      |   275    280    |         override suspend fun createSession(input: CreateSessionRequest): CreateSessionResponse {
 | 
      |   276    281    |             numCreateSession += 1
 | 
      |   277    282    |             throwExceptionOnBucketNamed?.let {
 | 
      |   278    283    |                 if (input.bucket == it) {
 | 
      |   279    284    |                     throw Exception("Failed to create session credentials for bucket: $throwExceptionOnBucketNamed")
 | 
      |   280    285    |                 }
 | 
      |   281    286    |             }
 | 
      |   282    287    |             return CreateSessionResponse { credentials = expectedCredentials }
 | 
      |   283    288    |         }
 | 
      |          289  + | 
 | 
      |          290  + |         override fun close() {
 | 
      |          291  + |             client.close()
 | 
      |          292  + |         }
 | 
      |   284    293    |     }
 | 
      |   285    294    | }
 |