Reverse State Monad in Scala. Is it possible?

Hello all!

In this post we’re going to have some fun with a mind-breaking thing called Reverse State, and explore the limits of laziness in Scala along the way.

When I see something interesting implemented in a “foreign” programming language, I often want to port it to Scala, purely out of curiosity about how it would look. Sometimes using a familiar language also helps to understand the presented concepts more deeply. Some time ago I did this with a great book called “Neural Networks and Deep Learning”: here are most of the exercises from the book written in Scala.

This time a totally different thing caught my eye. It was a really nice article about the Reverse State monad and its implementation in Haskell. I had never heard of it before, so implementing it in Scala seemed like an exciting exercise.

And it didn’t disappoint, although the outcome was not the one I expected 🙂 So I decided to make it a story instead of just plain code. I’ll use the technique that Eugene Yokota applied in his “Learning Scalaz”: we’ll follow the source (in Eugene’s case it was Learn You A Haskell For Great Good) piece by piece, discussing and writing the code along the way. Let’s go!

Prerequisites

To really follow along, the reader should be familiar with the Monoid and Traverse type classes, as well as with the State monad in Scala. A quick read of the original article won’t hurt either.

Warmup

There’s a big introductory part that’s dedicated to different ways of scanning a data structure. Let’s not skip this block and use it as a warmup.

Given a list of Ints, can you produce a cumulative sum of those integers? For example, if we had the list [2, 3, 5, 7, 11, 13], we want to have [2, 5, 10, 17, 28, 41].

There are actually many different ways to write this function. Depending on your taste on imperative programming, you can choose anywhere between highly imperative ST-based destructive updates and an idiomatic functional style.

This one is simple, since the Scala standard library has scanLeft:
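A minimal sketch of how this could look for a plain List[Int]. The name cumulativeInts is made up for this example, and the tail call is there because scanLeft puts the seed value in front of the result:

```scala
// Left-to-right running sum using only the standard library.
// scanLeft emits the seed (0) first, so we drop it with tail.
def cumulativeInts(xs: List[Int]): List[Int] =
  xs.scanLeft(0)(_ + _).tail

val sums = cumulativeInts(List(2, 3, 5, 7, 11, 13))
// sums == List(2, 5, 10, 17, 28, 41)
```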

 

For example, what if you want to produce a cumulative sum that is accumulated from the right? So if we had the list [2, 3, 5, 7, 11, 13], we want to have [41, 39, 36, 31, 24, 13].

Pretty much the same thing here:
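Again a minimal standard-library sketch with a made-up name (cumulativeIntsR). This time scanRight puts the seed at the end, so init drops it:

```scala
// Right-to-left running sum using only the standard library.
// scanRight emits the seed (0) last, so we drop it with init.
def cumulativeIntsR(xs: List[Int]): List[Int] =
  xs.scanRight(0)(_ + _).init

val sumsFromRight = cumulativeIntsR(List(2, 3, 5, 7, 11, 13))
// sumsFromRight == List(41, 39, 36, 31, 24, 13)
```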

 

As a Haskell programmer, we have the instinct to generalize things.

Say no more. Scala developers love it too. We’ll use type classes from cats, but scalaz would work the same way:
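The generalized signatures could look roughly like this. It is only a sketch: the enclosing trait CumulativeSums is a hypothetical container that I use here to show the signatures without bodies, and cats-core 2.x is assumed to be on the classpath:

```scala
//> using dep org.typelevel::cats-core:2.10.0
import cats.{Monoid, Traverse}

// Hypothetical trait, used only to state the signatures without implementations.
trait CumulativeSums {
  // running "sum" from the left, for any traversable T and any monoid A
  def cumulative[T[_]: Traverse, A: Monoid](xs: T[A]): T[A]

  // running "sum" from the right
  def cumulativeR[T[_]: Traverse, A: Monoid](xs: T[A]): T[A]
}
```

Traverse abstracts over the container, and Monoid supplies both the starting value (empty) and the way to combine elements.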

But how might we implement such a function? Let’s consider cumulative first. What we really need is to keep track of the running sum as we traverse, and then returning the running sum as the new value. The State monad then becomes helpful.

State is a well-known concept in Scala, so we can easily follow up with our implementation of cumulative:
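Here is a sketch of such an implementation, assuming cats-core 2.x is available. It may differ in details from the code in the playground, but it follows the same idea of traversing with State:

```scala
//> using dep org.typelevel::cats-core:2.10.0
import cats.{Monoid, Traverse}
import cats.data.State
import cats.syntax.all._

def cumulative[T[_]: Traverse, A: Monoid](xs: T[A]): T[A] =
  xs.traverse { a =>
    // For each element: combine it with the running sum, store the new sum
    // as the state and also return it as the value for this position.
    State[A, A] { acc =>
      val sum = acc |+| a
      (sum, sum)
    }
  }.runA(Monoid[A].empty).value

val sums = cumulative(List(2, 3, 5, 7, 11, 13))
// sums == List(2, 5, 10, 17, 28, 41)
```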

A few comments are in order here. “Traversing with State” is a powerful technique for going over a data structure while accumulating information along the way. Processing each element lets you modify the accumulated State “effect”. In this case we’re just accumulating the running sum (according to the provided Monoid) and using the same sum as the result value of the state calculation.

runA is analogous to evalState in Haskell: it evaluates the computation and returns just the result value, discarding the final state. And since State calculations in cats are wrapped in Eval for stack safety, we have to call value on the resulting Eval to get our result out.
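To make the difference between the run variants concrete, here is a tiny sketch (the step value is made up for illustration, and cats-core 2.x is assumed):

```scala
//> using dep org.typelevel::cats-core:2.10.0
import cats.data.State

// A made-up step: increments the counter and reports the value it saw.
val step: State[Int, String] =
  State[Int, String](n => (n + 1, s"saw $n"))

// Each run* method returns an Eval, hence the trailing .value.
val both: (Int, String) = step.run(41).value  // final state and result
val resultOnly: String  = step.runA(41).value // like evalState in Haskell
val stateOnly: Int      = step.runS(41).value // like execState in Haskell
```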

Ok, let’s now, using the original article’s help, try to implement cumulativeR with Reverse State.

Enter the reverse state monad

As mind-boggling as it is, let’s try to digest this definition:

The reverse state monad, on the other hand, has the same API, except that you can set the state, so that the last time you ask for it, you will get back the value you set in the future.

Oh man… Well, let’s at least try to port the provided implementation to Scala. I’ll change two things compared to the original, though:

  1. I’ll swap the results in the signature of the runF function to be consistent with normal State in cats, where the state is returned on the left, and the value is on the right.
  2. To implement cumulativeR we only need an Applicative, so I won’t try to provide an instance of Monad for ReverseState at this moment.

So this is our ReverseState applicative. We’re drowning in Evals, but that’s the cost of stack safety: everything here is really similar to the original State in cats, except for the ap function.

And, actually, no “clever use of laziness” is happening in the ap. Seems like it will show up in flatMap, but so far we’re fine without it  –  cumulativeR implementation works already:
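Below is a simplified sketch of both pieces. To keep it short I drop the Eval wrapping described above and use a plain function S => (S, A), so this version is not stack-safe and is not the exact code from the playground. It assumes cats-core 2.x:

```scala
//> using dep org.typelevel::cats-core:2.10.0
import cats.{Applicative, Monoid, Traverse}
import cats.syntax.all._

// Simplified: a plain function instead of Eval[S => Eval[(S, A)]].
final case class ReverseState[S, A](run: S => (S, A))

object ReverseState {
  // Type-lambda spelling that needs no compiler plugin;
  // with kind-projector this would read ReverseState[S, *].
  implicit def applicativeForReverseState[S]: Applicative[({ type L[X] = ReverseState[S, X] })#L] =
    new Applicative[({ type L[X] = ReverseState[S, X] })#L] {
      def pure[A](a: A): ReverseState[S, A] =
        ReverseState[S, A](s => (s, a))

      // The state is threaded "backwards": fa (the later computation)
      // sees the incoming state first, and ff receives what fa produced.
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        ReverseState[S, B] { s =>
          val (s1, a) = fa.run(s)
          val (s2, f) = ff.run(s1)
          (s2, f(a))
        }
    }
}

def cumulativeR[T[_]: Traverse, A: Monoid](xs: T[A]): T[A] =
  xs.traverse { a =>
    ReverseState[A, A] { acc =>
      val sum = a |+| acc
      (sum, sum)
    }
  }.run(Monoid[A].empty)._2

val input = List(2, 3, 5, 7, 11, 13)
val fromReverseState = cumulativeR(input)
val fromScanRight = input.scanRight(0)(_ + _).init
// both are List(41, 39, 36, 31, 24, 13)
```

The only difference from an ordinary State applicative is the order of the two run calls inside ap.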

We can check that its output is equivalent to the scanRight-based implementation:

Of course, due to laziness, a similar example in Haskell will not calculate anything until we explicitly ask for an element of the list or trigger the evaluation in some other way.

Let’s go ahead and implement the Scan generalization as presented in the original.

Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, what if we want a running average? Or a running standard deviation? Or some entirely new thing such as the running maximum multiplied by the minimum? The only difference between all of those tasks is that the specific state transforming function (the function that was passed to ReverseState) is different.

Since we don’t have proper universal quantification in Scala, I’ll just lift the x into the type parameters list and name it S (state). I could make it closer to the original using shapeless polys, but that’s not the topic of the post.

Here, we simply unwrap a given state monad action, wrap it again in our ReverseState, do the traversal, then unwrap it again.
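The unwrap, rewrap, traverse, unwrap sequence can be sketched as follows. The name scanR, its signature and the running-average step are my own choices for this illustration, and the simplified ReverseState (without Eval) is repeated so that the snippet stands alone. It assumes cats-core 2.x:

```scala
//> using dep org.typelevel::cats-core:2.10.0
import cats.{Applicative, Traverse}
import cats.data.State
import cats.syntax.all._

// Simplified ReverseState without Eval, repeated so this snippet is self-contained.
final case class ReverseState[S, A](run: S => (S, A))

object ReverseState {
  implicit def applicativeForReverseState[S]: Applicative[({ type L[X] = ReverseState[S, X] })#L] =
    new Applicative[({ type L[X] = ReverseState[S, X] })#L] {
      def pure[A](a: A): ReverseState[S, A] = ReverseState[S, A](s => (s, a))
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        ReverseState[S, B] { s =>
          val (s1, a) = fa.run(s) // later computation sees the state first
          val (s2, f) = ff.run(s1)
          (s2, f(a))
        }
    }
}

// Hypothetical name and signature: run an ordinary State step right-to-left.
def scanR[T[_]: Traverse, S, A, B](xs: T[A])(init: S)(step: A => State[S, B]): T[B] =
  xs.traverse { a =>
    // unwrap the State action and wrap it again as a ReverseState
    ReverseState[S, B](s => step(a).run(s).value)
  }.run(init)._2

// Running average, with (sum, count) as the state.
val average: Double => State[(Double, Int), Double] = x =>
  State[(Double, Int), Double] { case (sum, count) =>
    val (s, c) = (sum + x, count + 1)
    ((s, c), s / c)
  }

val averagesFromRight = scanR(List(1.0, 2.0, 3.0))((0.0, 0))(average)
// averagesFromRight == List(2.0, 2.5, 3.0)
```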

I find it beautiful! And it works, although I decided not to present standard deviation and max*min scans here. The former would require a lot of math and the latter needs proper composition abstractions for Scan, which fall out of the scope of this post.

So that’s it! We implemented everything introduced in the original post, and we did it in Scala! Except for…

FlatMap! Where is my FlatMap ???

The true power of state lies in the ability to sequence stateful computations using bind (or flatMap as we know it in Scala). But does it work for ReverseState?

In Haskell it definitely does. The laziness of the Haskell runtime allows bind to be a finite computation. Let’s take a closer look at the definition from the original article:

It’s clear that there’s a circular computational dependency between future and a: each of them is calculated in terms of the other. But that is fine: as long as we operate on finite data, at some point the next “future” state won’t be needed, and the Haskell runtime will evaluate only as much as is required to produce the result.

So what about Scala?

I would be happy to be proven wrong, but after hours of thought and experiments, after trying to wrap pretty much every tiny thing in Eval, I came to the conclusion that there’s no way to implement flatMap for ReverseState in Scala.

Although there’s a way to encode a circular dependency in Scala, there has to be an explicit exit from the “loop”. In other words, the computation of such a circular dependency will only complete when, under some runtime condition, the dependency disappears. The reason is simple: evaluation in Scala is strict by default, so the runtime can’t implicitly suspend computations that are not needed right now.
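As a small illustration of a circular dependency that does have an exit, here is a standard-library sketch with two made-up lazy values defined in terms of each other. It uses LazyList, so it assumes Scala 2.13 or newer (earlier versions would use Stream). The exit condition is that we only ever demand a finite prefix:

```scala
// evens is defined through odds and odds through evens.
// Nothing is computed until elements are actually requested.
lazy val evens: LazyList[Int] = 0 #:: odds.map(_ + 1)
lazy val odds: LazyList[Int]  = evens.map(_ + 1)

// Asking for a finite prefix breaks the cycle.
val firstEvens = evens.take(4).toList
// firstEvens == List(0, 2, 4, 6)
```

Without such an exit, for example two strict Int values each defined as the other plus one, the evaluation would have no way to complete.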

This restriction still allows some pretty interesting laziness tricks, like the loeb function, for example. But let’s take a look at what an implementation of flatMap for ReverseState might look like in Scala:

The circular dependency in the result is unconditional – the next leg of calculation is created regardless of any previous results.

Eval won’t help here, because to work inside Eval we need to sequence it with flatMap. So we won’t even be able to construct our Eval computation, since it would require circularly dependent flatMap calls on Eval. The flatMap calls themselves are eager and there’s no way to avoid that.

So, depending on whether we wrap the result into Eval.defer, we either get an infinite loop or a stack overflow for programs that involve flatMap-ing ReverseState.

Seems like we reached the limits of laziness in Scala here.

Update

There’s one case though where flatMap for ReverseState will work properly in Scala. It’s when your state type S is a lazy data structure (a standard Stream, for example).

It may seem like some random exceptional fact, but actually it’s the same case of providing the runtime with a condition to stop evaluation and break the circular computational dependency. This time it’s just less explicit and takes the form of Stream’s laziness.
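To give a feeling for this case, here is a rough sketch. It rests on several assumptions of mine and is not the code from the playground: the state type is fixed to LazyList[Int] (the Scala 2.13 successor of Stream), the type is called LazyReverseState to avoid confusion with the generic version, and the “future” state is handed over through the by-name #::: operator so that it is not forced too early:

```scala
// Sketch only: reverse state specialised to a lazy state type.
final case class LazyReverseState[A](run: LazyList[Int] => (LazyList[Int], A)) {
  def map[B](g: A => B): LazyReverseState[B] =
    LazyReverseState[B] { s =>
      val (s1, a) = run(s)
      (s1, g(a))
    }

  def flatMap[B](f: A => LazyReverseState[B]): LazyReverseState[B] =
    LazyReverseState[B] { s =>
      // Circular definitions: first needs future, future needs second,
      // second needs first. #::: takes its right operand by name, so
      // future is only looked at when the list is actually consumed.
      lazy val first: (LazyList[Int], A)  = run(LazyList.empty[Int] #::: future)
      lazy val second: (LazyList[Int], B) = f(first._2).run(s)
      lazy val future: LazyList[Int]      = second._1
      (first._1, second._2)
    }
}

val get: LazyReverseState[LazyList[Int]] =
  LazyReverseState[LazyList[Int]](s => (s, s))

def set(s: LazyList[Int]): LazyReverseState[Unit] =
  LazyReverseState[Unit](_ => (s, ()))

// get sees the state that is set *after* it, which is itself built from
// what get returned. The yielded value has to stay lazy as well.
val program: LazyReverseState[LazyList[Int]] =
  for {
    xs <- get
    _  <- set(1 #:: xs.map(_ * 2))
  } yield xs.take(5)

val powersOfTwo: List[Int] = program.run(LazyList.empty)._2.toList
// powersOfTwo == List(1, 2, 4, 8, 16)
```

Note how fragile this is: if the yielded value forced the list (for example by calling toList inside the for comprehension), the circular dependency would be evaluated before it is fully set up and the program would no longer terminate normally.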

Thanks to Oleg Nizhnik (@odomontois) for pointing me in this direction.

Conclusion

In this post we found out that ReverseState is not a Monad in Scala. Again, I would really love to be proven wrong here, so if you happen to find a working instance, please ping me!

It’s not a Monad, but it’s an Applicative, which means we still can use it in some meaningful computations 🙂

As an example of such computations, we looked at right-to-left stateful traversals. Big thanks to Zhouyu Qian from Capital Match for his post about ReverseState in Haskell, which served as the foundation for the post you just read.

Thanks for reading!

Code

Interactive playground with all of the presented code is available here.