Solution: ProductService

This is a simple solution:

class ProductService( private val productRepository: ProductRepository, backgroundScope: CoroutineScope, ) { private val activeObservers = AtomicInteger(0) fun observeProducts(categories: Set<String>): Flow<Product> = productRepository .observeProductUpdates() .distinctUntilChanged() .flatMapMerge(concurrency = Int.MAX_VALUE) { flow { emit(productRepository.fetchProduct(it)) } } .filter { it.category in categories } .onStart { activeObservers.incrementAndGet() } .onCompletion { activeObservers.decrementAndGet() } fun activeObserversCount(): Int = activeObservers.get() }

However, this solution has one efficiency issue: every new observer of observeProducts will start a new flow of product updates. This can can be optimized using SharedFlow, which will be introduced later.

Example solution in playground

import coroutines.flow.productservice.TestData.product1 import coroutines.flow.productservice.TestData.product2 import coroutines.flow.productservice.TestData.product3 import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.delay import kotlinx.coroutines.flow.* import kotlinx.coroutines.launch import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.currentTime import kotlinx.coroutines.test.runCurrent import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.Test import java.util.concurrent.atomic.AtomicInteger class ProductService( private val productRepository: ProductRepository, backgroundScope: CoroutineScope, ) { private val activeObservers = AtomicInteger(0) fun observeProducts(categories: Set<String>): Flow<Product> = productRepository .observeProductUpdates() .distinctUntilChanged() .flatMapMerge(concurrency = Int.MAX_VALUE) { flow { emit(productRepository.fetchProduct(it)) } } .filter { it.category in categories } .onStart { activeObservers.incrementAndGet() } .onCompletion { activeObservers.decrementAndGet() } fun activeObserversCount(): Int = activeObservers.get() } interface ProductRepository { // Emits ids of the products that got updated fun observeProductUpdates(): Flow<String> suspend fun fetchProduct(id: String): Product } data class Product( val id: String, val category: String, val name: String, val price: Double, ) @ExperimentalCoroutinesApi class ProductServiceTest { @Test fun `should emit distinct products`() = runTest { val fakeProductRepository = FakeProductRepository() val productService = ProductService(fakeProductRepository, backgroundScope) val categories = setOf(product1.category) val observedProducts = mutableListOf<Product>() productService.observeProducts(categories) .onEach { observedProducts.add(it) } .launchIn(backgroundScope) runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) runCurrent() assertEquals(1, observedProducts.size) assertEquals("1", observedProducts[0].id) } @Test fun `should filter products by category`() = runTest { val fakeProductRepository = FakeProductRepository() val productService = ProductService(fakeProductRepository, backgroundScope) val categories = setOf(product2.category) val observedProducts = mutableListOf<Product>() productService.observeProducts(categories) .onEach { observedProducts.add(it) } .launchIn(backgroundScope) runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) runCurrent() assertEquals(1, observedProducts.size) assertEquals("2", observedProducts[0].id) } @Test fun `should map product IDs to products`() = runTest { val fakeProductRepository = FakeProductRepository() val productService = ProductService(fakeProductRepository, backgroundScope) val categories = setOf(product1.category, product2.category) val observedProducts = mutableListOf<Product>() productService.observeProducts(categories) .onEach { observedProducts.add(it) } .launchIn(backgroundScope) runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) runCurrent() assertEquals(2, observedProducts.size) assertEquals("Smartphone", observedProducts[0].name) assertEquals("Novel", observedProducts[1].name) } @Test fun `should count active observers`() = runTest { val fakeProductRepository = FakeProductRepository() val productService = ProductService(fakeProductRepository, backgroundScope) assertEquals(0, productService.activeObserversCount()) val job1 = backgroundScope.launch { productService.observeProducts(setOf(product1.category)).collect() } runCurrent() assertEquals(1, productService.activeObserversCount()) val job2 = backgroundScope.launch { productService.observeProducts(setOf(product1.category)).collect() } runCurrent() assertEquals(2, productService.activeObserversCount()) job2.cancel() runCurrent() assertEquals(1, productService.activeObserversCount()) job1.cancel() runCurrent() assertEquals(0, productService.activeObserversCount()) } @Test fun `should not complete when there are active observers`() = runTest { val fakeProductRepository = FakeProductRepository() val productService = ProductService(fakeProductRepository, backgroundScope) val categories = setOf(product1.category) val observedProducts = mutableListOf<Product>() val job = backgroundScope.launch { productService.observeProducts(categories) .onEach { observedProducts.add(it) } .collect() } runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) runCurrent() assertEquals(1, observedProducts.size) assertEquals("1", observedProducts[0].id) job.cancel() runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) runCurrent() assertEquals(1, observedProducts.size) } @Test fun `should fetch products concurrently`() = runTest { val fakeProductRepository = FakeProductRepository(fetchProductsDelay = 1000) val productService = ProductService(fakeProductRepository, backgroundScope) val categories = setOf(product1.category, product2.category) val observedProducts = Channel<Product>(Channel.UNLIMITED) productService.observeProducts(categories) .onEach { observedProducts.send(it) } .launchIn(backgroundScope) runCurrent() fakeProductRepository.updatesHasOccurred(product1.id, product1.id, product2.id, product3.id) waiting for all products observedProducts.consumeAsFlow().take(2).collect() should take as much as a single fetch assertEquals(1000, currentTime) } @Test fun `should not limit the number of fetched products`() = runTest { val products = List(1000) { // Big number of products to fetch Product( id = "id$it", category = "ALL", name = "name$it", price = it.toDouble(), ) } val fakeProductRepository = FakeProductRepository(fetchProductsDelay = 1000, products = products) val productService = ProductService(fakeProductRepository, backgroundScope) val observedProducts = Channel<Product>(Channel.UNLIMITED) productService.observeProducts(setOf("ALL")) .onEach { observedProducts.send(it) } .launchIn(backgroundScope) runCurrent() fakeProductRepository.updatesHasOccurred(*products.map { it.id }.toTypedArray()) advanceUntilIdle() waiting for all products observedProducts.consumeAsFlow().take(products.size).collect() should take as much as a single fetch assertEquals(1000, currentTime) } } object TestData { val product1 = Product("1", "electronics", "Smartphone", 500.0) val product2 = Product("2", "books", "Novel", 20.0) val product3 = Product("3", "clothing", "T-Shirt", 15.0) } class FakeProductRepository( private val fetchProductsDelay: Long = 0, private val products: List<Product> = listOf( product1, product2, product3 ) ) : ProductRepository { private val updates = MutableSharedFlow<String>() var observers = 0 private set var productFetchCounter = 0 private set override fun observeProductUpdates(): Flow<String> = updates .onStart { observers++ } .onCompletion { observers-- } override suspend fun fetchProduct(id: String): Product { productFetchCounter++ if (fetchProductsDelay > 0) { delay(fetchProductsDelay) } return products.first { it.id == id } } suspend fun updatesHasOccurred(vararg ids: String) { ids.forEach { updates.emit(it) } } }