Saturday, December 01, 2007

Backwards State, or: The Power of Laziness

A recent discussion of Automatic Differentiation in Haskell led me to Jerzy Karczmarczuk's paper "Lazy Time Reversal, and Automatic Differentiation," which in turn cites Philip Wadler's "The essence of functional programming" for the introduction of the backwards state monad. I reproduce it here because I think it's neat.

I'm going to assume that you're familiar with the Haskell state monad. In summary, an action in the state monad is a function that takes the previous state and produces a result paired with the next state.

The backwards state monad differs in that the state flows through the execution in the reverse direction to the results. That is, an action in the backwards state monad takes in the final value of the state and produces a result and the initial value.
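
To make the two directions concrete before the real code starts, here is a minimal standalone sketch (not part of the literate source) that works on bare functions with no newtype. The names Action, bindF, bindB, tick, twoTicksF and twoTicksB are my own, made up for illustration.

```haskell
-- A state action as a bare function: state in, (result, state out).
type Action s a = s -> (a, s)

-- Ordinary (forward) bind: the state travels with the results.
bindF :: Action s a -> (a -> Action s b) -> Action s b
bindF m k = \s0 ->
  let (a, s1) = m s0
      (b, s2) = k a s1
  in  (b, s2)

-- Backwards bind: results still go left to right, but the state
-- enters on the right-hand action and leaves through the left one.
bindB :: Action s a -> (a -> Action s b) -> Action s b
bindB m k = \s2 ->
  let (a, s0) = m s1
      (b, s1) = k a s2
  in  (b, s0)

-- Return the current counter and increment it.
tick :: Action Int Int
tick n = (n, n + 1)

-- Run tick twice and pair up the results, once under each bind.
twoTicksF :: Action Int (Int, Int)
twoTicksF = tick `bindF` \a -> tick `bindF` \b -> \s -> ((a, b), s)

twoTicksB :: Action Int (Int, Int)
twoTicksB = tick `bindB` \a -> tick `bindB` \b -> \s -> ((a, b), s)

-- twoTicksF 0 == ((0, 1), 2)   the first tick sees the counter first
-- twoTicksB 0 == ((1, 0), 2)   the second tick sees the counter first
```

Both versions hand back the same state in the end. Only the order in which the actions get to see it differs.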

This post is literate Haskell - you should be able to copy and paste it into a .lhs file and play with it in a Haskell interpreter. I use GHCi.

To that end, here's some of the up-front boilerplate so this all works:

> {-# LANGUAGE FlexibleInstances,
>              MultiParamTypeClasses,
>              RecursiveDo
>   #-}
> import Data.List
> import Control.Monad.Fix
> import Control.Monad.State

An Example


Here's the exercise: Given a tree of items, transform the tree to a tree of Ints such that each element is mapped to an Int, starting at 0. If an element occurs more than once in the tree, it must be mapped to the same Int each time.

The solution given in Control.Monad.State.Lazy does a walk of the tree, and carries around a list of all of the elements seen so far using the state monad. Each node is mapped to its position in this list. That is, the first node seen is mapped to 0, the second to 1, and so on.

But what if I wanted to switch that up? What if I wanted the last node seen in the walk mapped to 0, the second to last mapped to 1, and so on? How much would I need to change in the existing solution given in Control.Monad.State.Lazy?

Not much! I'd just need to use the backwards state monad, where the state flows backwards through the thread of execution.

This is what the modified solution would look like:

> data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Show, Eq)
> type Table a = [a]

> numberTree :: Eq a => Tree a -> StateB (Table a) (Tree Int)
> numberTree Nil = return Nil
> numberTree (Node x t1 t2)
>   = do num <- atomically $ numberNode x
>        nt1 <- numberTree t1
>        nt2 <- numberTree t2
>        return (Node num nt1 nt2)
>   where
>     numberNode :: Eq a => a -> State (Table a) Int
>     numberNode x
>       = do table <- get
>            (newTable, newPos) <- return (nNode x table)
>            put newTable
>            return newPos

> nNode :: (Eq a) => a -> Table a -> (Table a, Int)
> nNode x table
>   = case elemIndex x table of
>       Nothing -> (table ++ [x], length table)
>       Just i  -> (table, i)

And an evaluation function:

> numTree :: (Eq a) => Tree a -> Tree Int
> numTree t = evalStateB (numberTree t) []

Some test data:

> testTree = Node "Zero" (Node "One" (Node "Two" Nil Nil) (Node "One" (Node "Three" Nil Nil) Nil)) Nil

Executing numTree testTree will produce the output:

Node 3 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 0 Nil Nil) Nil)) Nil

Which is exactly what we wanted!
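
If you want to see where those numbers come from without any monad in the way, here is a standalone sketch (not part of the literate source) that threads the table by hand in both directions. numberFwd and numberBwd are names I made up for illustration. The backwards one is only my reading of what the StateB version computes for this walk, not code from Control.Monad.State.Lazy or from Wadler's paper.

```haskell
import Data.List (elemIndex)

data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Show, Eq)

-- Look an element up in the table, appending it if it is new.
nNode :: Eq a => a -> [a] -> ([a], Int)
nNode x table = case elemIndex x table of
  Nothing -> (table ++ [x], length table)
  Just i  -> (table, i)

-- Forward threading: the node sees the table first, then the left
-- subtree, then the right subtree.
numberFwd :: Eq a => Tree a -> [a] -> (Tree Int, [a])
numberFwd Nil table = (Nil, table)
numberFwd (Node x l r) table0 =
  let (table1, n) = nNode x table0
      (l', table2) = numberFwd l table1
      (r', table3) = numberFwd r table2
  in  (Node n l' r', table3)

-- Backward threading: same tree shape in the result, but the table
-- visits the right subtree first and the node itself last.
numberBwd :: Eq a => Tree a -> [a] -> (Tree Int, [a])
numberBwd Nil table = (Nil, table)
numberBwd (Node x l r) table3 =
  let (r', table2) = numberBwd r table3
      (l', table1) = numberBwd l table2
      (table0, n)  = nNode x table1
  in  (Node n l' r', table0)

testTree :: Tree String
testTree = Node "Zero" (Node "One" (Node "Two" Nil Nil) (Node "One" (Node "Three" Nil Nil) Nil)) Nil

-- fst (numberFwd testTree [])
--   == Node 0 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 3 Nil Nil) Nil)) Nil
-- fst (numberBwd testTree [])
--   == Node 3 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 0 Nil Nil) Nil)) Nil
```

The point of the backwards state monad is that you get the second behaviour while still writing the code in the order of the first.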

This code is almost exactly the same as the solution given to the in-order problem in the source to Control.Monad.State.Lazy. The only changes are the use of the function evalStateB instead of the familiar evalState, the use of the function atomically, and the StateB monad. The implementation of these is explained below.

First the API, then the implementation.

The API


We have the new monad StateB s, where s is the type of the stored state.

StateB s is an instance of MonadState s, so get and put are as expected.

There is also:

> -- runStateB :: StateB s a -> s -> (a, s)
> evalStateB :: StateB s a -> s -> a
> execStateB :: StateB s a -> s -> s

which should look familiar. The trick is that the state s passed in to these functions is the final state, and the state returned is the initial state. In the example above, remember that the last element seen in the walk was given the first label, and the first element seen in the walk was given the last.
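
Here is a small standalone sketch of how those functions behave (again not part of the literate source). It redefines a cut-down StateB using only base so it can be loaded on its own. peekB, doubleThenInc and peekEarly are hypothetical names of mine, and the instances are just enough to make do-notation work.

```haskell
newtype StateB s a = StateB { runStateB :: s -> (a, s) }

instance Functor (StateB s) where
  fmap f (StateB m) = StateB $ \s -> let (a, s') = m s in (f a, s')

instance Applicative (StateB s) where
  pure a = StateB $ \s -> (a, s)
  mf <*> mx = mf >>= \f -> mx >>= \x -> pure (f x)

instance Monad (StateB s) where
  StateB m >>= f = StateB $ \s2 ->
    let (a, s0) = m s1
        (b, s1) = runStateB (f a) s2
    in  (b, s0)

evalStateB :: StateB s a -> s -> a
evalStateB m = fst . runStateB m

execStateB :: StateB s a -> s -> s
execStateB m = snd . runStateB m

modifyB :: (s -> s) -> StateB s ()
modifyB f = StateB $ \s -> ((), f s)

-- Hypothetical helper: read the state visible at this point.
peekB :: StateB s s
peekB = StateB $ \s -> (s, s)

-- Written first, applied last: (+ 1) runs on the state passed in,
-- then (* 2) runs on the result.
doubleThenInc :: StateB Int ()
doubleThenInc = do
  modifyB (* 2)
  modifyB (+ 1)

-- peekB comes first in the code but sees the state after the
-- later modifyB has already been applied.
peekEarly :: StateB Int Int
peekEarly = do
  x <- peekB
  modifyB (+ 1)
  return x

-- execStateB doubleThenInc 10 == 22   (forward State would give 21)
-- evalStateB peekEarly 10     == 11   (forward State would give 10)
```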

The implementation of modify in Control.Monad.State.Class is as follows:

-- modify :: MonadState s m => (s -> s) -> m ()
-- modify f = do
--   s <- get
--   put (f s)

In the StateB monad, this code will bottom out whenever f inspects the state it is given, because of the circular data dependency between the two monadic actions. In the backwards state monad, (>>=) passes the result forward and the state backwards, so the code above forms a nice loop: the first line grabs the updated state from the second line and tries to pass it in as an argument to the second line.
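
You can watch the loop form by unfolding the backwards bind by hand. Here is a standalone sketch on bare functions (not part of the literate source). Action, bindB, getB, putB, naiveModify and directModify are names I made up for it. Substituting the definitions gives s1 = f s1, so the get/put version ties the state into a fixed point of f instead of applying f once.

```haskell
type Action s a = s -> (a, s)

-- The backwards bind, written on bare functions.
bindB :: Action s a -> (a -> Action s b) -> Action s b
bindB m k = \s2 ->
  let (a, s0) = m s1
      (b, s1) = k a s2
  in  (b, s0)

getB :: Action s s
getB s = (s, s)

putB :: s -> Action s ()
putB s = \_ -> ((), s)

-- modify transcribed literally: get, then put (f s).
-- Unfolding bindB:  a = s1  and  s1 = f a,  hence  s1 = f s1.
naiveModify :: (s -> s) -> Action s ()
naiveModify f = getB `bindB` \s -> putB (f s)

-- Applying f directly to the incoming state avoids the cycle.
directModify :: (s -> s) -> Action s ()
directModify f = \s -> ((), f s)

-- With a lazy f the fixed point is still observable:
--   take 3 (snd (naiveModify (1 :) [])) == [1, 1, 1]
--   snd (directModify (1 :) [])         == [1]
-- With a strict f such as (+ 1), forcing snd (naiveModify (+ 1) 0)
-- never produces a value.
```

Note that naiveModify ignores the state it is handed entirely, which is another way of seeing that it is not doing what modify is supposed to do.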

To make this work, we need a version of modify specific to StateB:

> modifyB :: (s -> s) -> StateB s ()

But if you want to modify the state and return the result, you'll need something more sophisticated:

> atomically :: State s a -> StateB s a

atomically converts an action under the normal state monad to a single action under StateB, allowing you to do complex updates to the state easily without bottoming out (using mdo notation also works).

Implementation


The base of the implementation is taken directly from Wadler's paper.

The StateB monad is almost the same as the State monad - each action with a result of type a is a function of type s -> (a,s). The difference is in the implementation of (>>=).

Let's start with the monad:

> newtype StateB s a = StateB {runStateB :: s -> (a,s)}

> -- GHC 7.10 and later require an Applicative instance for every Monad.
> instance Applicative (StateB s) where
>   pure = StateB . unitS
>   mf <*> mx = mf >>= \f -> mx >>= \x -> return (f x)

> instance Monad (StateB s) where
>   return = StateB . unitS
>   (StateB m) >>= f = StateB $ m `bindS` (runStateB . f)

Because wrapping and unwrapping the newtype annoys me, all of that is confined to the exported functions (like return and (>>=)). The functions that deal directly with the underlying type all have an 'S' suffix.

> m `bindS` k = \s2 ->
>   let (a, s0) = m s1
>       (b, s1) = k a s2
>   in  (b, s0)

> unitS a = \s2 -> (a, s2)

As you can see, the state passed in is acted on by the RHS of bindS, the intermediate state is consumed by the LHS, and the LHS produces the resulting state, s0. The pattern bindings in the let are lazy, which is what allows a and s1 to be defined in terms of each other. It looks too simple to work, but it does.

And the other API functions:

> execStateB m = snd . runStateB m

> evalStateB m = fst . runStateB m

> modifyB = StateB . modify'
> where modify' f = \s -> ((), f s)

> atomically = StateB . runState

Just for funsies:

> instance Functor (StateB s) where
>   fmap f m = StateB $ mapS f (runStateB m)

> mapS f m = \s -> let (a, s') = m s in (f a, s')

> instance MonadState s (StateB s) where
>   get = StateB get'
>     where get' = \s -> (s,s)
>
>   put = StateB . put'
>     where put' s = const ((),s)

> instance MonadFix (StateB s) where
>   mfix = StateB . mfixS . (runStateB .)

> mfixS f = \s2 ->
>   let (a,s0) = (f b) s1
>       (b,s1) = (f a) s2
>   in  (b,s0)

The transformer


Now a treat for those of you still paying attention. I haven't really tested this, but it looks like it should work and that's good enough for me. A lot of this is in the style of the sources for Control.Monad.State.Lazy.

> newtype StateBT s m a = StateBT {runStateBT :: s -> m (a,s)}

> unitST a = \s -> return (a,s)

> m `bindST` k = \s2 -> mdo
>   ~(a,s0) <- m s1
>   ~(b,s1) <- k a s2
>   return (b,s0)

> execStateBT :: Monad m => StateBT s m a -> s -> m s
> execStateBT m s = do
>   ~(_,s') <- runStateBT m s
>   return s'

> evalStateBT :: Monad m => StateBT s m a -> s -> m a
> evalStateBT m s = do
>   ~(a,_) <- runStateBT m s
>   return a

> modifyBT :: Monad m => (s -> s) -> StateBT s m ()
> modifyBT = StateBT . modify'
>   where modify' f = \s -> return ((),f s)

> atomicallyT :: Monad m => State s a -> StateBT s m a
> atomicallyT m = StateBT $ \s-> return $ runState m s

> atomicallyTM :: Monad m => StateT s m a -> StateBT s m a
> atomicallyTM = StateBT . runStateT

> mapST f m = \s -> do
>   ~(a,s') <- m s
>   return (f a,s')

> liftST m = \s -> do
>   a <- m
>   return (a,s)

> mfixST f = \s2 -> mdo
>   ~(a,s0) <- (f b) s1
>   ~(b,s1) <- (f a) s2
>   return (b,s0)

> instance Monad m => Functor (StateBT s m) where
>   fmap f m = StateBT $ mapST f (runStateBT m)

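> -- GHC 7.10 and later require an Applicative instance for every Monad.
> instance MonadFix m => Applicative (StateBT s m) where
>   pure = StateBT . unitST
>   mf <*> mx = mf >>= \f -> mx >>= \x -> return (f x)

> instance MonadFix m => Monad (StateBT s m) where
>   return = StateBT . unitST
>   (StateBT m) >>= f = StateBT $ m `bindST` (runStateBT . f)

> -- From GHC 8.8 onwards, fail lives in MonadFail rather than Monad.
> instance (MonadFix m, MonadFail m) => MonadFail (StateBT s m) where
>   fail = StateBT . const . fail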

> instance MonadTrans (StateBT s) where
>   lift = StateBT . liftST

> instance MonadFix m => MonadState s (StateBT s m) where
>   get = StateBT get'
>     where get' = \s -> return (s,s)
>
>   put = StateBT . put'
>     where put' s = const $ return ((),s)

> instance MonadFix m => MonadFix (StateBT s m) where
>   mfix = StateBT . mfixST . (runStateBT .)

Watching:

  • House
  • Ride Back