Tuesday, April 8, 2014

Kotlin: Pattern matching

Over the past year I have evolved a simple pattern matching scheme that I now routinely incorporate into new Kotlin class hierarchies. The scheme is trivial to include, costs next to nothing to program and offers some useful features. It relies on several of the nice features offered by Kotlin, notably function literals.

Consider the development of an immutable list type. The architecture comprises the trait ListIF, the abstract class ListAB, and the two concrete classes Nil and Cons. The ListIF trait operates as an interface, introducing the behaviours supported. The abstract class ListAB is the home for the implementation of most of the common behaviours. The concrete classes Nil and Cons are used to create list instances.

The implementation for this architecture is shown in Example 1. Note the match member function introduced in the trait. This is the pattern matching scheme for this class hierarchy. The function is given two parameters, one for each concrete class. The first parameter is a function to convert a Nil into the result type B. The second parameter is a function from a Cons to the result type B.

The concrete subclasses Nil and Cons override the definition for match. In the class Nil the first parameter is applied to the Nil instance. In the class Cons the second parameter is applied to the Cons instance. This pattern matching scheme closely resembles the visitor design pattern, with the two function parameters standing in for the visitor's methods, one per concrete class.

Example 1: Pattern matching

package example1

import kotlin.test.*

trait ListIF<A> {
    fun size(): Int
    fun length(): Int
    fun interleave(xs: ListIF<A>): ListIF<A>

    fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B
}

abstract class ListAB<A> : ListIF<A> {
    override fun size(): Int {...}

    override fun length(): Int {...}
    
    override fun interleave(xs: ListIF<A>): ListIF<A> {...}
}

class Nil<A> : ListAB<A>() {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = nil(this)
}

class Cons<A>(val hd: A, val tl: ListIF<A>) : ListAB<A>() {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = cons(this)
}

fun main(args: Array<String>) {
    val xs: ListIF<Int> = Cons(1, Cons(3, Cons(5, Nil<Int>())))
    val ys: ListIF<Int> = Cons(2, Cons(4, Nil<Int>()))

    assertEquals(3, xs.size())
    assertEquals(2, ys.length())
    assertEquals(4, xs.interleave(ys).length())
}

The definition for member function size is:

override fun size(): Int {
    return this.match(
        {(nil: Nil<A>) -> 0},
        {(cons: Cons<A>) -> 1 + cons.tl.size()}
    )
}

A match is performed against the receiver list: if it is empty, zero is returned; if it is non-empty, we return 1 plus the size of the remaining list. Of course we could achieve the same using a when expression with a number of is branches. With when, however, it is incumbent on the programmer to provide all the necessary choices. With the match function, omitting a choice produces a compiler error.
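To make the comparison concrete, here is a minimal sketch of the two styles side by side. It is written in current Kotlin syntax (interface rather than trait, and untyped lambda parameters), which is an assumption on my part rather than the syntax used in Example 1. The helper names sizeWithWhen and sizeWithMatch are hypothetical and exist only for this illustration.

```kotlin
// Sketch in current Kotlin: `interface` replaces the older `trait` keyword.
interface ListIF<A> {
    fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B
}

class Nil<A> : ListIF<A> {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = nil(this)
}

class Cons<A>(val hd: A, val tl: ListIF<A>) : ListIF<A> {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = cons(this)
}

// Hypothetical helper: size written with `when` and `is`.
// ListIF is an open hierarchy, so the compiler cannot check that every
// subclass is covered; an `else` branch is required and the programmer
// must make sure the `is` branches are complete.
fun <A> sizeWithWhen(xs: ListIF<A>): Int = when (xs) {
    is Nil<A> -> 0
    is Cons<A> -> 1 + sizeWithWhen(xs.tl)
    else -> throw IllegalStateException("unexpected list class")
}

// Hypothetical helper: the same logic through match.
// Leaving out either function argument is a compile-time error.
fun <A> sizeWithMatch(xs: ListIF<A>): Int =
    xs.match({ 0 }, { cons -> 1 + sizeWithMatch(cons.tl) })

fun main() {
    val xs: ListIF<Int> = Cons(1, Cons(3, Cons(5, Nil<Int>())))
    println(sizeWithWhen(xs))   // prints 3
    println(sizeWithMatch(xs))  // prints 3
}
```

In the when version a forgotten branch only shows up at run time through the else; in the match version the signature of match itself enumerates the cases.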

The second function literal given to match reveals another feature. The formal parameter cons is of type Cons<A> and is effectively a smart cast of this. Hence in the body of this function literal we can reference the tl property of cons directly.

The definition for member function interleave shows how we handle nested pattern matching:

override fun interleave(xs: ListIF<A>): ListIF<A> {
    return this.match(
        {(nil1: Nil<A>) -> Nil<A>()},
        {(cons1: Cons<A>) ->
            xs.match(
                {(nil2: Nil<A>) -> Nil<A>()},
                {(cons2: Cons<A>) -> Cons(cons1.hd, Cons(cons2.hd, cons1.tl.interleave(cons2.tl)))}
            )
        }
    )
}

Interleaving the values from two lists requires a number of considerations. They are fully captured by the nested pattern matching. If the first list (this) is empty then an empty list is delivered. If the first list is not empty then we look to the second list (xs). If this is empty then an empty list is returned. Otherwise, for two non-empty lists we prefix the head of the first and the head of the second on to the result of interleaving their tails. In Example 1, for instance, interleaving [1, 3, 5] with [2, 4] yields [1, 2, 3, 4]: the unpaired 5 is dropped, which is why the expected length is 4. As well as giving a type-safe implementation of our logic, the nested match lets us reason about its correctness before we test it.
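The case analysis above can be checked with a self-contained sketch. It restates the hierarchy in current Kotlin syntax (an assumption; Example 1 uses the older trait keyword and typed function literals) and supplies explicit type arguments to match to keep type inference simple. The helper toKotlinList is hypothetical and is only there so the result can be printed.

```kotlin
interface ListIF<A> {
    fun size(): Int
    fun length(): Int
    fun interleave(xs: ListIF<A>): ListIF<A>
    fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B
}

abstract class ListAB<A> : ListIF<A> {
    override fun size(): Int =
        match<Int>({ 0 }, { cons -> 1 + cons.tl.size() })

    override fun length(): Int {
        // Accumulating helper, as in the article's length function.
        fun recLength(xs: ListIF<A>, acc: Int): Int =
            xs.match<Int>({ acc }, { cons -> recLength(cons.tl, 1 + acc) })
        return recLength(this, 0)
    }

    override fun interleave(xs: ListIF<A>): ListIF<A> =
        match<ListIF<A>>(
            { Nil<A>() },                      // this is empty
            { cons1 ->
                xs.match<ListIF<A>>(
                    { Nil<A>() },              // xs is empty
                    { cons2 ->                 // both non-empty
                        Cons(cons1.hd, Cons(cons2.hd, cons1.tl.interleave(cons2.tl)))
                    }
                )
            }
        )
}

class Nil<A> : ListAB<A>() {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = nil(this)
}

class Cons<A>(val hd: A, val tl: ListIF<A>) : ListAB<A>() {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = cons(this)
}

// Hypothetical helper: copy the elements into a standard Kotlin List for display.
fun <A> toKotlinList(xs: ListIF<A>): List<A> =
    xs.match<List<A>>({ emptyList() }, { cons -> listOf(cons.hd) + toKotlinList(cons.tl) })

fun main() {
    val xs: ListIF<Int> = Cons(1, Cons(3, Cons(5, Nil<Int>())))
    val ys: ListIF<Int> = Cons(2, Cons(4, Nil<Int>()))
    val zs = xs.interleave(ys)
    println(toKotlinList(zs))  // prints [1, 2, 3, 4]
    println(zs.length())       // prints 4
}
```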

The recursive definition of the list data type naturally leads to a recursive definition of the classes and their member functions. Here is member function length:

override fun length(): Int {
    fun recLength(xs: ListIF<A>, acc: Int): Int {
        return xs.match(
            {(nil: Nil<A>) -> acc},
            {(cons: Cons<A>) -> recLength(cons.tl, 1 + acc)}
        )
    }

    return recLength(this, 0)
}

The nested function recLength has an accumulating parameter so that the recursive call might exploit tail call optimization. If I understand the Kotlin annotation tailRecursive correctly, it does not apply here, since the recursive call is nested in the function literal and so is not a direct tail call of recLength. Perhaps the clever people at JetBrains could find a way so that I can have my cake and eat it.
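One possible workaround, sketched here under assumptions, is to drop match for this one function and use when, so that the recursive call is a direct tail call. The sketch uses current Kotlin, where the tailRecursive annotation has become the tailrec modifier. The names lengthViaMatch and lengthViaWhen are hypothetical.

```kotlin
interface ListIF<A> {
    fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B
}

class Nil<A> : ListIF<A> {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = nil(this)
}

class Cons<A>(val hd: A, val tl: ListIF<A>) : ListIF<A> {
    override fun <B> match(nil: (Nil<A>) -> B, cons: (Cons<A>) -> B): B = cons(this)
}

// Hypothetical helper using match. The recursive call lives inside a lambda,
// so it is not a tail call of lengthViaMatch itself and the stack grows with
// the length of the list.
fun <A> lengthViaMatch(xs: ListIF<A>, acc: Int): Int =
    xs.match({ acc }, { cons -> lengthViaMatch(cons.tl, 1 + acc) })

// Hypothetical alternative using `when`. The recursive call is now the last
// action of the function, so `tailrec` lets the compiler turn it into a loop.
// The price is that the compiler no longer checks the cases for us.
tailrec fun <A> lengthViaWhen(xs: ListIF<A>, acc: Int): Int = when (xs) {
    is Cons<A> -> lengthViaWhen(xs.tl, 1 + acc)
    else -> acc
}

fun main() {
    // Build a long list iteratively so construction itself does not recurse.
    var xs: ListIF<Int> = Nil<Int>()
    for (i in 1..100_000) xs = Cons(i, xs)

    println(lengthViaWhen(xs, 0))  // prints 100000

    val small: ListIF<Int> = Cons(1, Cons(2, Nil<Int>()))
    println(lengthViaMatch(small, 0))  // prints 2
}
```

The trade-off is the one the article describes: the when version gets the loop but gives up the compiler-checked case analysis that match provides.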

Of course not all my class hierarchies are defined recursively, and so this issue does not always arise. The pattern matching scheme is inexpensive to implement, is type safe, supports smart casts, and gives function bodies a clarity that lets us reason about their correctness.