Solution: TokenRepository

/**
 * Caches a [Token] obtained from a [TokenClient] and fetches a new one when none is cached
 * or the cached one has expired according to [TimeProvider.now].
 *
 * [getToken] runs under a [Mutex], so at most one [TokenClient.fetchToken] call is in progress
 * at a time; concurrent callers wait for it and then reuse the cached result.
 * [invalidateToken] deliberately does NOT take the mutex: it clears the cache immediately and
 * never waits for a fetch that is in progress (that fetch still stores its token when it ends).
 */
class TokenRepository(
    private val client: TokenClient,
    private val timeProvider: TimeProvider
) {
    // invalidateToken() writes this field outside the mutex, so @Volatile is needed for that
    // write to be visible to getToken() running on another thread.
    @Volatile
    private var token: Token? = null
    private val mutex = Mutex()

    /**
     * Returns the cached token while its expiration is later than [TimeProvider.now];
     * otherwise calls [TokenClient.fetchToken], caches the result, and returns it.
     */
    suspend fun getToken(): Token = mutex.withLock {
        token?.takeIf { cached -> cached.expiration > timeProvider.now() }
            ?: client.fetchToken().also { fetched -> token = fetched }
    }

    /** Clears the cached token so that the next [getToken] call fetches a fresh one. */
    fun invalidateToken() {
        token = null
    }
}

Here we must use a Mutex in getToken to make sure that only one call to fetchToken is in progress at a time. We must not use the same mutex in invalidateToken: invalidation would then have to wait until an in-flight fetch completes, and it would end up discarding the freshly fetched token.

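We cannot use a synchronized block, because we need to suspend instead of blocking the thread (and calling a suspending function such as fetchToken inside a synchronized block is a compile error anyway). We would use synchronized if fetchToken were a blocking operation and we wanted getToken to be a blocking operation as well.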

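We cannot use an atomic reference or a dispatcher limited to a single thread, because neither guarantees that fetchToken is called only once at a time. An atomic reference cannot make other callers wait while a fetch is in progress. On a single-threaded dispatcher, coroutines still interleave whenever fetchToken suspends, so a second caller could start another fetch.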

Example solution in playground

import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.test.currentTime
import kotlinx.coroutines.test.runTest
import org.junit.Test
import java.time.LocalDateTime
import java.time.ZoneOffset.UTC
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
import kotlin.time.Duration.Companion.hours
import kotlin.time.Duration.Companion.minutes

/**
 * Caches a [Token] obtained from a [TokenClient] and fetches a new one when none is cached
 * or the cached one has expired according to [TimeProvider.now].
 *
 * [getToken] runs under a [Mutex], so at most one [TokenClient.fetchToken] call is in progress
 * at a time; concurrent callers wait for it and then reuse the cached result.
 * [invalidateToken] deliberately does NOT take the mutex: it clears the cache immediately and
 * never waits for a fetch that is in progress (that fetch still stores its token when it ends).
 */
class TokenRepository(
    private val client: TokenClient,
    private val timeProvider: TimeProvider
) {
    // invalidateToken() writes this field outside the mutex, so @Volatile is needed for that
    // write to be visible to getToken() running on another thread.
    @Volatile
    private var token: Token? = null
    private val mutex = Mutex()

    /**
     * Returns the cached token while its expiration is later than [TimeProvider.now];
     * otherwise calls [TokenClient.fetchToken], caches the result, and returns it.
     */
    suspend fun getToken(): Token = mutex.withLock {
        token?.takeIf { cached -> cached.expiration > timeProvider.now() }
            ?: client.fetchToken().also { fetched -> token = fetched }
    }

    /** Clears the cached token so that the next [getToken] call fetches a fresh one. */
    fun invalidateToken() {
        token = null
    }
}

/** A token [value] that [TokenRepository] keeps reusing while [expiration] is in the future. */
data class Token(val value: String, val expiration: LocalDateTime)

/** Source of the current time, injectable so that tests can control it. */
interface TimeProvider {
    fun now(): LocalDateTime
}

/** Remote source of new tokens. */
interface TokenClient {
    suspend fun fetchToken(): Token
}

class TokenRepositoryTest {

    @Test
    fun `should fetch token first, and then not fetch it if available`() = runTest {
        // given
        var timesFetchTokenCalled = 0
        val token = Token("token", LocalDateTime.now().plusHours(1))
        val client = object : TokenClient {
            override suspend fun fetchToken(): Token {
                delay(1000)
                timesFetchTokenCalled++
                return token
            }
        }
        val timeProvider = object : TimeProvider {
            override fun now(): LocalDateTime = LocalDateTime.now()
        }
        val repository = TokenRepository(client, timeProvider)

        // when fetching token three times
        assertEquals(token, repository.getToken())
        assertEquals(token, repository.getToken())
        assertEquals(token, repository.getToken())

        // then fetchToken should be called only once, and it should take as long as the first call
        assertEquals(1, timesFetchTokenCalled)
        assertEquals(1000L, currentTime)
    }

    @Test
    fun `should fetch token if expired`() = runTest {
        // given
        var i = 1
        val timeProvider = object : TimeProvider {
            // Derived from the virtual time of runTest, so delay() moves this clock forward.
            override fun now(): LocalDateTime =
                LocalDateTime.ofEpochSecond(currentTime / 1000, 0, UTC)
        }
        val client = object : TokenClient {
            override suspend fun fetchToken(): Token =
                Token("token${i++}", timeProvider.now().plusHours(1))
        }
        val repository = TokenRepository(client, timeProvider)

        // when fetching token
        val token = repository.getToken()

        // then token should have the appropriate expiration date
        assertEquals(
            currentTime / 1000 + 1.hours.inWholeSeconds,
            token.expiration.toEpochSecond(UTC)
        )

        // when fetching token after expiration
        delay(1.hours + 5.minutes)
        val newToken = repository.getToken()

        // then it is a new token
        assertNotEquals(token, newToken)

        // then new token should have the appropriate expiration date
        assertEquals(
            currentTime / 1000 + 1.hours.inWholeSeconds,
            newToken.expiration.toEpochSecond(UTC)
        )
    }

    @Test
    fun `should generate only one token for multiple requests`() = runTest {
        // given
        var i = 1
        var fetchTokenCalled = 0
        val timeProvider = object : TimeProvider {
            override fun now(): LocalDateTime = LocalDateTime.now()
        }
        val client = object : TokenClient {
            override suspend fun fetchToken(): Token {
                fetchTokenCalled++
                delay(1000)
                return Token("token${i++}", timeProvider.now().plusHours(1))
            }
        }
        val repository = TokenRepository(client, timeProvider)

        // when fetching token by multiple coroutines concurrently
        val tokens = coroutineScope {
            (1..10)
                .map { async { repository.getToken() } }
                .awaitAll()
        }

        // then all tokens should be the same
        assertEquals(1, tokens.toSet().size)
        // and fetchToken should be called only once
        assertEquals(1, fetchTokenCalled)
        // and it should take as long as the first call
        assertEquals(1000L, currentTime)
    }

    @Test
    fun `should provide new token after invalidation`() = runTest {
        // given
        var i = 1
        val timeProvider = object : TimeProvider {
            override fun now(): LocalDateTime = LocalDateTime.now()
        }
        val client = object : TokenClient {
            override suspend fun fetchToken(): Token =
                Token("token${i++}", timeProvider.now().plusHours(1))
        }
        val repository = TokenRepository(client, timeProvider)

        // when fetching token
        val token = repository.getToken()

        // then token should be the same
        assertEquals(token, repository.getToken())

        // when invalidating token
        repository.invalidateToken()

        // then new token should be different
        assertNotEquals(token, repository.getToken())
    }

    @Test
    fun `should not invalidate token that is currently fetching`() = runTest {
        // given
        var i = 1
        val timeProvider = object : TimeProvider {
            override fun now(): LocalDateTime = LocalDateTime.now()
        }
        val client = object : TokenClient {
            override suspend fun fetchToken(): Token {
                delay(1000)
                return Token("token${i++}", timeProvider.now().plusHours(1))
            }
        }
        val repository = TokenRepository(client, timeProvider)

        // when fetching token
        val tokenAsync = async { repository.getToken() }

        // and in-between invalidating token
        delay(500)
        repository.invalidateToken()

        // when fetching token again
        val token2 = repository.getToken()
        // or fetching later
        val token = tokenAsync.await()
        val token3 = repository.getToken()

        // then token should be the same
        assertEquals(token, token2)
        assertEquals(token, token3)
    }
}