Datatype-Generic Programming in Lean4

2023-09-09

A while back I ran across the paper Datatype-Generic Programming by Jeremy Gibbons and wanted to implement it. Lean4 is primarily a language for formalizing mathematics, but it is also geared toward more conventional functional programming, so it seemed like a good fit for this.

The goal here is simply to work through implementing the idea, both to understand it better and to get some experience with Lean. Whether the implementation is any good, or whether implementing it at all is a good idea, is out of scope.

Background - Generic Programming

Typical generic programming (the paper discusses many other kinds) abstracts over data structures, and the functions on them, that differ only in an underlying type.

For example, if we have the type of lists of integers and the type of lists of strings, we might end up with two different append functions that look nearly identical.

inductive ListI where
| Nil : ListI
| Cons : Int → ListI → ListI

def AppendI (a b : ListI) : ListI :=
  match a with
  | ListI.Nil => b
  | ListI.Cons head rest => ListI.Cons head (AppendI rest b)

inductive ListS where
| Nil : ListS
| Cons : String → ListS → ListS

def AppendS (a b : ListS) : ListS :=
  match a with
  | ListS.Nil => b
  | ListS.Cons head rest => ListS.Cons head (AppendS rest b)

By using generic programming we can write a single List type and a single Append function, both parameterized by the underlying type.

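-- A namespace keeps List and Append from clashing with Lean's built-in names.
namespace Generic

inductive List (α : Type) where
| Nil : List α
| Cons : α → List α → List α

def Append {α : Type} (a b : List α) : List α :=
  match a with
  | List.Nil => b
  | List.Cons head rest => List.Cons head (Append rest b)

end Generic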
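As a quick check that one definition really does serve several element types, here is a minimal sketch. To avoid redefining List, it is written against Lean's built-in List with its [] and :: notation. The name appendG is hypothetical.

```lean
-- One append for every element type α, written against Lean's built-in List.
def appendG {α : Type} : List α → List α → List α
  | [], ys => ys
  | x :: rest, ys => x :: appendG rest ys

-- The same definition serves lists of Int and lists of String.
example : appendG ([1, 2] : List Int) [3] = [1, 2, 3] := rfl
example : appendG ["a"] ["b", "c"] = ["a", "b", "c"] := rfl
```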

Background - Datatype-Generic Programming

Similarly, inductive datatypes like lists and trees share common functions that we can define on them, such as map and fold:

-- The namespace avoids a clash with Lean's built-in List.
namespace Background

inductive List (α : Type) where
| Nil : List α
| Cons : α → List α → List α

def mapL {α β : Type} (f : α → β) : (List α → List β) := fun l1 =>
  match l1 with
  | List.Nil => List.Nil
  | List.Cons head rest => List.Cons (f head) (mapL f rest)

def foldL {α β : Type} (seed : β) (f : β → α → β) (list : List α) : β :=
  match list with
  | List.Nil => seed
  | List.Cons head rest => f (foldL seed f rest) head

inductive BTree (α : Type) where
| Tip : α → BTree α
| Bin : BTree α → BTree α → BTree α

def mapB {α β : Type} (f : α → β) (t : BTree α) : BTree β :=
  match t with
  | BTree.Tip a => BTree.Tip (f a)
  | BTree.Bin left right => BTree.Bin (mapB f left) (mapB f right)

def foldB {α β : Type} (g : α → β) (f : β → β → β) (tree : BTree α) : β :=
  match tree with
  | BTree.Tip val => g val
  | BTree.Bin left right => f (foldB g f left) (foldB g f right)

end Background

Here you have to squint a little, but the two maps and the two folds are doing very analogous things. The paper emphasizes that the differences between the functions are solely due to the 'shape' of the data structures they operate on. If we can abstract away that 'shape', then we could write a single version of each function that works on List, BTree, and almost any other inductive datatype we might need.

Implementation

The shape is abstracted using the somewhat mysterious Fix data structure:

unsafe inductive Fix (s : Type → Type → Type) (a : Type) : Type
| In : s a (Fix s a) → Fix s a

-- This is the reverse of the In constructor
unsafe def out (s : Type → Type → Type) (a : Type) (t : Fix s a) : s a (Fix s a) :=
  match t with
  | Fix.In x => x

We need to define Fix with 'unsafe' because Lean only accepts inductive datatypes whose recursive occurrences are strictly positive. This requirement is part of what guarantees that all code terminates. In Fix, the recursive occurrence sits inside an arbitrary s, so Lean cannot check it. I'm not sure whether there's a way around this, but for this implementation it doesn't matter much.
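For a single, fixed shape there is a partial workaround, sketched here on the assumption that giving up genericity over s is acceptable. Lean accepts an ordinary (safe) inductive type whose recursive occurrence is nested inside a concrete shape such as the ListF shown just below. It accepts this because it can see that the occurrence is positive; what it cannot do is check an arbitrary s. The namespace Workaround and the names ListFix and oneTwo are hypothetical.

```lean
namespace Workaround

-- The list shape: β marks the recursive position.
inductive ListF (α : Type) (β : Type) : Type where
  | Nil : ListF α β
  | Cons : α → β → ListF α β

-- Tying the knot for this one concrete shape is accepted without `unsafe`
-- (as a nested inductive), because ListFix only occurs positively in ListF.
inductive ListFix (α : Type) : Type where
  | In : ListF α (ListFix α) → ListFix α

-- The list [1, 2] encoded through the shape.
def oneTwo : ListFix Nat :=
  ListFix.In (ListF.Cons 1 (ListFix.In (ListF.Cons 2 (ListFix.In ListF.Nil))))

end Workaround
```

The price is that every shape needs its own ListFix-style type, which is exactly the duplication Fix was meant to remove.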

Here, the s parameter is the 'shape' that we're trying to abstract and a is the underlying type parameterizing that shape (e.g. Int or String in the above examples).

So what do these 'shapes' look like? Essentially they are the original inductive datatype definitions with each recursive occurrence replaced by an extra type parameter (β below). We then define the actual datatypes using Fix and these 'shape' types:

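inductive ListF (α : Type) (β : Type) : Type where
| Nil : ListF α β
| Cons : α → β → ListF α β

-- Reducible, so that instances written for Fix ListF α are also found for List α.
-- (Lean already has a built-in List, so this belongs inside a namespace.)
@[reducible] unsafe def List (α : Type) := Fix ListF α

inductive BTreeF (α : Type) (β : Type) : Type where
| Tip : α → BTreeF α β
| Bin : β → β → BTreeF α β

@[reducible] unsafe def BTree (α : Type) := Fix BTreeF α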

To make things more concrete, let's look at how to construct values of the actual datatypes. To do this, we just pass the shape type's constructors to the Fix.In constructor:

unsafe def List.Nil {a : Type} : List a := Fix.In ListF.Nil
unsafe def List.Cons {a : Type} (head : a) (rest : List a) : List a := Fix.In (ListF.Cons head rest)

unsafe def BTree.Tip {a : Type} (tip : a) := Fix.In (BTreeF.Tip tip)
unsafe def BTree.Bin {a : Type} (l r : BTree a) := Fix.In (BTreeF.Bin l r)
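Going the other way, out peels one Fix.In layer off a value and exposes its shape. A consumer that only needs the outermost layer can therefore pattern-match on the ListF constructors directly, with no fold involved. The sketch below repeats the relevant definitions inside a hypothetical OutSketch namespace, which also keeps List from clashing with Lean's built-in List. headOr is a hypothetical helper, not something from the paper.

```lean
namespace OutSketch

unsafe inductive Fix (s : Type → Type → Type) (a : Type) : Type
  | In : s a (Fix s a) → Fix s a

-- The inverse of Fix.In: expose the outermost layer of shape.
unsafe def out (s : Type → Type → Type) (a : Type) (t : Fix s a) : s a (Fix s a) :=
  match t with
  | Fix.In x => x

inductive ListF (α : Type) (β : Type) : Type where
  | Nil : ListF α β
  | Cons : α → β → ListF α β

-- Inside this namespace, List means OutSketch.List rather than the built-in type.
unsafe def List (α : Type) := Fix ListF α
unsafe def List.Nil {α : Type} : List α := Fix.In ListF.Nil
unsafe def List.Cons {α : Type} (head : α) (rest : List α) : List α :=
  Fix.In (ListF.Cons head rest)

-- Hypothetical helper: the first element, or `fallback` for an empty list.
-- Only one layer is inspected, so no recursion (and no fold) is needed.
unsafe def headOr {α : Type} (fallback : α) (l : List α) : α :=
  match out ListF α l with
  | ListF.Nil => fallback
  | ListF.Cons head _ => head

end OutSketch
```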

The last thing we need before implementing the datatype-generic functions is a 'bimap' for the shape types. It takes two functions and applies the first to the values of the shape's element type and the second to the values in its recursive positions. This behavior is captured by the Bifunctor type class:

class Bifunctor (F : Type → Type → Type) where
(bimap : ∀ {α α' β β'}, (α → α') → (β → β') → F α β → F α' β')
export Bifunctor ( bimap )

In theory, types that are meant to be bifunctors must also satisfy a few laws, which we skip here for brevity. So all we need to do is implement bimap:

instance : Bifunctor ListF where
  bimap := fun {γ : Type} {γ' : Type} {β : Type} {β' : Type} (f : γ → γ') (g : β → β') (t: ListF γ β) => match t with
    | ListF.Nil => ListF.Nil
    | ListF.Cons head rest => ListF.Cons (f head) (g rest)

instance : Bifunctor BTreeF where
  bimap := fun {γ : Type} {γ' : Type} {β : Type} {β' : Type} (f : γ → γ') (g : β → β') (t: BTreeF γ β) => match t with
    | BTreeF.Tip tip => BTreeF.Tip (f tip)
    | BTreeF.Bin l r => BTreeF.Bin (g l) (g r)
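The properties mentioned above are the usual bifunctor laws: bimap id id should be the identity, and bimap should respect function composition. As a sketch, here they are stated and checked for ListF. The theorem names are hypothetical, and the definitions are repeated inside a Laws namespace so the block stands on its own. Both proofs go through by case analysis followed by unfolding.

```lean
namespace Laws

class Bifunctor (F : Type → Type → Type) where
  bimap : ∀ {α α' β β'}, (α → α') → (β → β') → F α β → F α' β'
export Bifunctor (bimap)

inductive ListF (α : Type) (β : Type) : Type where
  | Nil : ListF α β
  | Cons : α → β → ListF α β

instance : Bifunctor ListF where
  bimap := fun f g t => match t with
    | ListF.Nil => ListF.Nil
    | ListF.Cons head rest => ListF.Cons (f head) (g rest)

-- Identity law: mapping id over both parameters changes nothing.
theorem bimap_id {α β : Type} (t : ListF α β) : bimap id id t = t := by
  cases t <;> rfl

-- Composition law: two bimaps in a row equal one bimap of the composed functions.
theorem bimap_comp {α α' α'' β β' β'' : Type}
    (f : α' → α'') (f' : α → α') (g : β' → β'') (g' : β → β') (t : ListF α β) :
    bimap (f ∘ f') (g ∘ g') t = bimap f g (bimap f' g' t) := by
  cases t <;> rfl

end Laws
```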

Now we'll just implement the datatype-generic functions by copying from the paper without comment:

unsafe def map (α β : Type) [Bifunctor s] (f: α → β) : (Fix s α → Fix s β) :=
  Fix.In ∘ bimap f (map α β f) ∘ (out s α)

unsafe def fold (α β : Type) [Bifunctor s] (f: s α β → β) : (Fix s α → β) :=
  f ∘ bimap id (fold α β f) ∘ (out s α)
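Lean evaluates arguments eagerly, unlike the Haskell in the paper, so it can be reassuring to see the same recursion written pointfully. In that form a recursive reference such as map f stays a partial application until bimap applies it to a child. The sketch below is a restatement, not code from the paper. It repeats the needed definitions inside a hypothetical Pointful namespace and adds hypothetical fromList and sum helpers for a quick check.

```lean
namespace Pointful

class Bifunctor (F : Type → Type → Type) where
  bimap : ∀ {α α' β β'}, (α → α') → (β → β') → F α β → F α' β'
export Bifunctor (bimap)

unsafe inductive Fix (s : Type → Type → Type) (a : Type) : Type
  | In : s a (Fix s a) → Fix s a

unsafe def out {s : Type → Type → Type} {a : Type} : Fix s a → s a (Fix s a)
  | Fix.In x => x

inductive ListF (α : Type) (β : Type) : Type where
  | Nil : ListF α β
  | Cons : α → β → ListF α β

instance : Bifunctor ListF where
  bimap := fun f g t => match t with
    | ListF.Nil => ListF.Nil
    | ListF.Cons head rest => ListF.Cons (f head) (g rest)

-- Same equations as the point-free versions, but the recursive calls
-- `map f` and `fold f` are only applied once a concrete value `t` is in hand.
unsafe def map {s : Type → Type → Type} [Bifunctor s] {α β : Type}
    (f : α → β) (t : Fix s α) : Fix s β :=
  Fix.In (bimap f (map f) (out t))

unsafe def fold {s : Type → Type → Type} [Bifunctor s] {α β : Type}
    (f : s α β → β) (t : Fix s α) : β :=
  f (bimap id (fold f) (out t))

-- Hypothetical helper: convert a built-in list into the Fix encoding.
unsafe def fromList {α : Type} : List α → Fix ListF α
  | [] => Fix.In ListF.Nil
  | x :: xs => Fix.In (ListF.Cons x (fromList xs))

-- Hypothetical helper: sum the elements with the generic fold.
unsafe def sum (l : Fix ListF Int) : Int :=
  fold (fun (layer : ListF Int Int) => match layer with
    | ListF.Nil => 0
    | ListF.Cons head rest => head + rest) l

end Pointful
```

The only change from the point-free form is that the argument t is named explicitly. The equations themselves are the same.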

And there we have it. As an example usage, let's make our data structures instances of ToString using fold:

unsafe instance [ToString α] : ToString (Fix ListF α) where
  toString (list : List α) : String :=
    fold α String (fun foo : ListF α String => match foo with
    | ListF.Nil => ""
    | ListF.Cons head rest => s!"{head}, {rest}") list

unsafe instance [ToString α] : ToString (Fix BTreeF α) where
  toString (tree: BTree α) : String :=
    fold α String (fun foo : BTreeF α String => match foo with
    | BTreeF.Tip tip => toString tip
    | BTreeF.Bin l r => s!"({l} {r})") tree

Conclusion

Beyond being practice in translating the paper's ideas into code, this implementation leaves something to be desired. There is the ergonomic issue that everything downstream of Fix has to be marked 'unsafe'. More fundamentally, every use of fold still requires you to break the 'shape' type down into cases.

Still an interesting exercise though.

Appendix: Full Code

-- Core Lean already defines List, so everything lives in a namespace.
namespace DGP

class Bifunctor (F : Type → Type → Type) where
(bimap : ∀ {α α' β β'}, (α → α') → (β → β') → F α β → F α' β')
export Bifunctor ( bimap )

unsafe inductive Fix (s : Type → Type → Type) (a : Type) : Type
| In : s a (Fix s a) → Fix s a

unsafe def out (s : Type → Type → Type) (a : Type) (t : Fix s a) : s a (Fix s a) :=
  match t with
  | Fix.In x => x

unsafe def map (α β : Type) [Bifunctor s] (f: α → β) : (Fix s α → Fix s β) :=
  Fix.In ∘ bimap f (map α β f) ∘ (out s α)

unsafe def fold (α β : Type) [Bifunctor s] (f: s α β → β) : (Fix s α → β) :=
  f ∘ bimap id (fold α β f) ∘ (out s α)

inductive ListF (α : Type) (β : Type) : Type where
| Nil : ListF α β
| Cons : α → β → ListF α β
deriving Repr

instance : Bifunctor ListF where
  bimap := fun {γ : Type} {γ' : Type} {β : Type} {β' : Type} (f : γ → γ') (g : β → β') (t: ListF γ β) => match t with
    | ListF.Nil => ListF.Nil
    | ListF.Cons head rest => ListF.Cons (f head) (g rest)

-- Reducible, so that #eval finds the ToString instance for Fix ListF α.
@[reducible] unsafe def List (α : Type) := Fix ListF α
unsafe def List.Nil {α : Type} : List α := Fix.In ListF.Nil
unsafe def List.Cons {α : Type} (head : α) (rest : List α) : List α := Fix.In (ListF.Cons head rest)

unsafe instance [ToString α] : ToString (Fix ListF α) where
  toString (list : List α) : String :=
    fold α String (fun foo : ListF α String => match foo with
    | ListF.Nil => ""
    | ListF.Cons head rest => s!"{head}, {rest}") list

unsafe def testList : List Int :=
  (List.Cons 5 (List.Cons 4 (List.Cons 87 (List.Nil))))

#eval testList

#eval (map Int Int (fun x => x + 1) testList)

#eval fold Int Int (fun foo : ListF Int Int => match foo with
  | ListF.Nil => 0
  | ListF.Cons head rest => head + rest) testList

inductive BTreeF (α : Type) (β : Type) : Type where
| Tip : α → BTreeF α β
| Bin : β → β → BTreeF α β

instance : Bifunctor BTreeF where
  bimap := fun {γ : Type} {γ' : Type} {β : Type} {β' : Type} (f : γ → γ') (g : β → β') (t: BTreeF γ β) => match t with
    | BTreeF.Tip tip => BTreeF.Tip (f tip)
    | BTreeF.Bin l r => BTreeF.Bin (g l) (g r)

-- Reducible for the same reason as List.
@[reducible] unsafe def BTree (α : Type) := Fix BTreeF α
unsafe def BTree.Tip {α : Type} (tip : α) := Fix.In (BTreeF.Tip tip)
unsafe def BTree.Bin {α : Type} (l r : BTree α) := Fix.In (BTreeF.Bin l r)

unsafe instance [ToString α] : ToString (Fix BTreeF α) where
  toString (tree: BTree α) : String :=
    fold α String (fun foo : BTreeF α String => match foo with
    | BTreeF.Tip tip => toString tip
    | BTreeF.Bin l r => s!"({l} {r})") tree

unsafe def testTree : BTree Int :=
  BTree.Bin (BTree.Bin (BTree.Tip 5) (BTree.Tip 4)) (BTree.Tip 87)

#eval testTree

#eval (map Int Int (fun x => x + 1) testTree)

#eval fold Int Int (fun foo : BTreeF Int Int => match foo with
  | BTreeF.Tip tip => tip
  | BTreeF.Bin l r => l + r) testTree

end DGP