
Collection processing in Kotlin: Folding and reducing

This is a chapter from the book Functional Kotlin. You can find it on LeanPub or Amazon. It is also available as a course.

// fold implementation from Kotlin stdlib
inline fun <T, R> Iterable<T>.fold(
    initial: R,
    operation: (acc: R, T) -> R
): R {
    var accumulator = initial
    for (element in this) {
        accumulator = operation(accumulator, element)
    }
    return accumulator
}

fold is the most universal method in our collection processing toolbox. We rarely use it because the Kotlin standard library already provides the most important aggregate operations, but if a method for a specific task is missing, fold is at our service.

Let's see it in practice. fold accumulates all elements into a single value (called an "accumulator") using a defined operation. For instance, let's say that our collection contains the numbers from 1 to 4, our initial accumulator value is 0, and our operation is addition. In that case, fold will:

  • add the first value 1 to the initial accumulator value 0,
  • then it will add the result 1 to the next value 2,
  • then it will add the result 3 to the next value 3,
  • then it will add the result 6 to the next value 4,
  • and the result is 10.

As you can see, fold(0) { acc, i -> acc + i } calculates the sum of all the numbers.
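To make the steps listed above visible, here is a small sketch that performs the same fold while recording each step. The helper name `tracedSum` is hypothetical and exists only for this illustration.

```kotlin
// Hypothetical helper: sums the numbers with fold and records every step,
// so the sequence of accumulator values can be inspected.
fun tracedSum(numbers: List<Int>): Pair<Int, List<String>> {
    val steps = mutableListOf<String>()
    val result = numbers.fold(0) { acc, i ->
        val next = acc + i
        steps += "$acc + $i = $next" // record this step
        next // the value returned here becomes the next accumulator
    }
    return result to steps
}

fun main() {
    val (sum, steps) = tracedSum(listOf(1, 2, 3, 4))
    steps.forEach { println(it) }
    // 0 + 1 = 1
    // 1 + 2 = 3
    // 3 + 3 = 6
    // 6 + 4 = 10
    println(sum) // 10
}
```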

Since you specify the initial value, you also decide the result type. If your initial value is an empty string and your operation is addition (string concatenation), the result will be the string "1234".

fun main() {
    val numbers = listOf(1, 2, 3, 4)
    val sum = numbers.fold(0) { acc, i -> acc + i }
    println(sum) // 10
    val joinedString = numbers.fold("") { acc, i -> acc + i }
    println(joinedString) // 1234
    val product = numbers.fold(1) { acc, i -> acc * i }
    println(product) // 24
}
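The result type can be anything, not just a number or a string. As a sketch, the hypothetical `countWords` below uses an empty map as the initial value, so fold produces a `Map<String, Int>` of word counts.

```kotlin
// Hypothetical helper: counts how many times each word occurs.
// The initial value (an empty map) makes the result a Map<String, Int>.
fun countWords(words: List<String>): Map<String, Int> =
    words.fold(emptyMap()) { acc, word ->
        // `acc + pair` returns a new map with this entry added or replaced
        acc + (word to ((acc[word] ?: 0) + 1))
    }

fun main() {
    println(countWords(listOf("kotlin", "fold", "kotlin"))) // {kotlin=2, fold=1}
}
```

This version copies the map at every step, which is fine for illustration; for real code, the standard library already offers `groupingBy { it }.eachCount()`.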

fold is so universal that nearly all collection processing methods can be implemented with it.

// simplified `filter` implemented with `fold`
inline fun <T> Iterable<T>.filter(
    predicate: (T) -> Boolean
): List<T> = fold(emptyList()) { acc, e ->
    if (predicate(e)) acc + e else acc
}
// simplified `map` implemented with `fold`
inline fun <T, R> Iterable<T>.map(
    transform: (T) -> R
): List<R> = fold(emptyList()) { acc, e ->
    acc + transform(e)
}
// simplified `flatMap` implemented with `fold`
inline fun <T, R> Iterable<T>.flatMap(
    transform: (T) -> Iterable<R>
): List<R> = fold(emptyList()) { acc, e ->
    acc + transform(e)
}
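In the same spirit, a function that returns two lists at once can be built with a `Pair` accumulator. This is a simplified sketch of what the standard `partition` does; the name `partitionByFold` is hypothetical, chosen to avoid clashing with the stdlib function.

```kotlin
// Hypothetical name, to avoid clashing with the stdlib `partition`.
// The accumulator is a Pair of lists: (matching, non-matching).
inline fun <T> Iterable<T>.partitionByFold(
    predicate: (T) -> Boolean
): Pair<List<T>, List<T>> =
    fold(Pair(emptyList<T>(), emptyList<T>())) { (yes, no), e ->
        if (predicate(e)) Pair(yes + e, no) else Pair(yes, no + e)
    }

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    println(numbers.partitionByFold { it % 2 == 0 }) // ([2, 4], [1, 3, 5])
}
```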

On the other hand, because the Kotlin standard library has so many collection processing functions, we rarely need fold. Even the sum and string-joining operations we presented earlier have dedicated methods.

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    println(numbers.sum()) // 15
    println(numbers.joinToString(separator = "")) // 12345
}

There is currently no standard library method to calculate the product of all the numbers in a collection, so this is where fold can be used. We might use it directly, or we might use it to implement the product method ourselves.

fun Iterable<Int>.product(): Int = fold(1) { acc, i -> acc * i }
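A quick usage sketch of this `product` extension. Because the initial value 1 is the identity for multiplication, an empty collection yields 1 instead of failing. Keep in mind that, like any `Int` arithmetic in Kotlin, the product silently overflows for large results.

```kotlin
fun Iterable<Int>.product(): Int = fold(1) { acc, i -> acc * i }

fun main() {
    println(listOf(1, 2, 3, 4, 5).product()) // 120
    println(emptyList<Int>().product()) // 1
}
```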

If you want to reverse the order of accumulation (to start from the end of the collection), use foldRight. Note that its operation takes the element first and the accumulator second.
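A minimal comparison of the two directions: concatenating digits shows the processing order, and the lambda parameters of `foldRight` appear in the swapped order (element, accumulator).

```kotlin
fun main() {
    val numbers = listOf(1, 2, 3, 4)
    // fold: (accumulator, element), processes 1, 2, 3, 4
    println(numbers.fold("") { acc, i -> acc + i }) // 1234
    // foldRight: (element, accumulator), processes 4, 3, 2, 1
    println(numbers.foldRight("") { i, acc -> acc + i }) // 4321
}
```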

In some situations, you might want not only the final result of fold but also all the intermediate accumulator values. For that, you can use the runningFold method or its alias¹ scan.

fun main() {
    val numbers = listOf(1, 2, 3, 4)
    println(numbers.fold(0) { acc, i -> acc + i }) // 10
    println(numbers.scan(0) { acc, i -> acc + i }) // [0, 1, 3, 6, 10]
    println(numbers.runningFold(0) { acc, i -> acc + i }) // [0, 1, 3, 6, 10]
    println(numbers.fold("") { acc, i -> acc + i }) // 1234
    println(numbers.scan("") { acc, i -> acc + i }) // [, 1, 12, 123, 1234]
    println(numbers.runningFold("") { acc, i -> acc + i }) // [, 1, 12, 123, 1234]
    println(numbers.fold(1) { acc, i -> acc * i }) // 24
    println(numbers.scan(1) { acc, i -> acc * i }) // [1, 1, 2, 6, 24]
    println(numbers.runningFold(1) { acc, i -> acc * i }) // [1, 1, 2, 6, 24]
}

runningFold(init, oper).last() and scan(init, oper).last() always give the same result as fold(init, oper).
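A typical use of intermediate values is a running total. The sketch below assumes a hypothetical account that starts at 100 and receives a list of deposits and withdrawals; runningFold gives the balance after every transaction, and its last element matches the plain fold result.

```kotlin
// Hypothetical example: account balance after each transaction.
fun balances(start: Int, transactions: List<Int>): List<Int> =
    transactions.runningFold(start) { balance, change -> balance + change }

fun main() {
    val transactions = listOf(50, -30, -70, 20)
    val history = balances(100, transactions)
    println(history) // [100, 150, 120, 50, 70]
    // The last intermediate value equals the plain fold result
    println(history.last() == transactions.fold(100) { b, c -> b + c }) // true
}
```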

reduce

// simplified reduce implementation from Kotlin stdlib
public inline fun <S, T : S> Iterable<T>.reduce(
    operation: (acc: S, T) -> S
): S {
    val iterator = this.iterator()
    if (!iterator.hasNext()) throw UnsupportedOperationException(
        "Empty collection can't be reduced."
    )
    var accumulator: S = iterator.next()
    while (iterator.hasNext()) {
        accumulator = operation(accumulator, iterator.next())
    }
    return accumulator
}

reduce is very similar to fold: it also accumulates all elements using a defined operation. The difference is that reduce does not take an initial value; instead, it uses the first element as the initial accumulator. This has three consequences:

  • If a collection is empty, reduce throws an exception. If we are not certain that a collection has elements, we should use reduceOrNull, which returns null for an empty collection.
  • The result of reduce must be of the element type (or one of its supertypes), since the first element becomes the initial accumulator.
  • reduce is slightly faster than fold because it performs one operation fewer.
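A short sketch of the first consequence: on an empty list, fold returns its initial value, reduceOrNull returns null, and reduce throws an `UnsupportedOperationException`.

```kotlin
fun main() {
    val empty = emptyList<Int>()
    println(empty.fold(0) { acc, i -> acc + i }) // 0
    println(empty.reduceOrNull { acc, i -> acc + i }) // null
    try {
        empty.reduce { acc, i -> acc + i }
    } catch (e: UnsupportedOperationException) {
        println("reduce threw UnsupportedOperationException")
    }
}
```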

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    println(numbers.fold(0) { acc, i -> acc + i }) // 15
    println(numbers.reduce { acc, i -> acc + i }) // 15
    println(numbers.fold("") { acc, i -> acc + i }) // 12345
    // Here `reduce` cannot be used instead of `fold`
    println(numbers.fold(1) { acc, i -> acc * i }) // 120
    println(numbers.reduce { acc, i -> acc * i }) // 120
}

For a non-empty list, list.reduce(oper) behaves like list.drop(1).fold(list[0], oper).
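This equivalence can be checked with an order-sensitive operation. In the sketch below, each step appends a digit, so any difference in processing order would change the result. Both expressions assume a non-empty list: reduce would throw on an empty one, and `list[0]` would throw `IndexOutOfBoundsException`.

```kotlin
fun main() {
    val list = listOf(3, 1, 4, 1, 5)
    // Order-sensitive operation: appends digit `i` to the accumulator
    val oper = { acc: Int, i: Int -> acc * 10 + i }
    println(list.reduce(oper)) // 31415
    println(list.drop(1).fold(list[0], oper)) // 31415
}
```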

In general, I prefer using fold whenever there is a "zero" value, because fold handles empty collections safely and lets me control the result type.

Just like for fold, there are runningReduce and reduceRight.
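A small sketch of both. runningReduce works like runningFold but starts from the first element, so the result has no separate initial value. Subtraction is not associative, which makes the different grouping of reduce and reduceRight visible; as with foldRight, the reduceRight lambda takes the element first and the accumulator second.

```kotlin
fun main() {
    val numbers = listOf(1, 2, 3, 4)
    println(numbers.runningReduce { acc, i -> acc + i }) // [1, 3, 6, 10]
    // reduce groups from the left: ((1 - 2) - 3) - 4
    println(numbers.reduce { acc, i -> acc - i }) // -8
    // reduceRight groups from the right: 1 - (2 - (3 - 4))
    println(numbers.reduceRight { i, acc -> i - acc }) // -2
}
```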

sum

// simplified sample sum implementation from Kotlin stdlib
fun Iterable<Int>.sum(): Int {
    var sum: Int = 0
    for (element in this) {
        sum += element
    }
    return sum
}
// simplified sample sumOf implementation from Kotlin stdlib
inline fun <T> Iterable<T>.sumOf(
    selector: (T) -> Int
): Int {
    var sum: Int = 0.toInt()
    for (element in this) {
        sum += selector(element)
    }
    return sum
}

I mentioned that there is already a function to calculate the sum of all the numbers in a collection, and its name is sum. It is defined for all the basic numeric types, like Int, Long, Double, etc.

fun main() {
    val numbers = listOf(1, 6, 2, 4, 7, 1)
    println(numbers.sum()) // 21
    val doubles = listOf(0.1, 0.6, 0.2, 0.4, 0.7)
    println(doubles.sum()) // 1.9999999999999998
    // It is not 2, due to limited JVM double representation
    val bytes = listOf<Byte>(1, 4, 2, 4, 5)
    println(bytes.sum()) // 16
}
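For our own types, sum is not available, and this is exactly where fold fills the gap. The sketch below assumes a hypothetical `Money` class (stored in cents) with a `plus` operator and a `ZERO` constant; the extension name `total` is also hypothetical.

```kotlin
// Hypothetical type: an amount of money stored in cents.
data class Money(val cents: Long) {
    operator fun plus(other: Money) = Money(cents + other.cents)

    companion object {
        val ZERO = Money(0)
    }
}

// Hypothetical extension: sums Money values with fold, starting from Money.ZERO.
fun Iterable<Money>.total(): Money = fold(Money.ZERO) { acc, m -> acc + m }

fun main() {
    val payments = listOf(Money(250), Money(1_000), Money(75))
    println(payments.total()) // Money(cents=1325)
    println(emptyList<Money>().total()) // Money(cents=0)
}
```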

When you have a list of elements and want to calculate the sum of one of their properties, you could first map the elements onto the values of that property, but it is more efficient to use sumOf, which extracts a numeric value from each element and sums these values without creating an intermediate list.

import java.math.BigDecimal
data class Player(
    val name: String,
    val points: Int,
    val money: BigDecimal,
)
fun main() {
    val players = listOf(
        Player("Jake", 234, BigDecimal("2.30")),
        Player("Megan", 567, BigDecimal("1.50")),
        Player("Beth", 123, BigDecimal("0.00")),
    )
    println(players.map { it.points }.sum()) // 924
    println(players.sumOf { it.points }) // 924
    // Works for `BigDecimal` as well
    println(players.sumOf { it.money }) // 3.80
}
1: In this chapter, by aliases we mean functions with exactly the same meaning.