In this post we’re going to have some fun with a mind-breaking thing called Reverse State, and explore the limits of laziness in Scala along the way.
When I see something interesting implemented in a “foreign” programming language, I often want to port it to Scala, just out of curiosity about how it would look. Sometimes using a familiar language also helps me understand the concepts more deeply. Some time ago I did this with a great book called “Neural Networks and Deep Learning”: here are most of the exercises from the book written in Scala.
This time a totally different thing caught my eye. It was a really nice article about the Reverse State monad and its implementation in Haskell. I had never heard of it before, so implementing it in Scala seemed like an exciting exercise.
And it didn’t disappoint, although the outcome was not the one I expected 🙂 So I decided to make it a story instead of just plain code. I’ll use the technique that Eugene Yokota applied in his “Learning Scalaz”: we’ll follow the source (in Eugene’s case it was Learn You A Haskell For Great Good) piece by piece, discussing and writing the code along the way. Let’s go!
To really follow along, the reader should be familiar with the Traverse type class as well as with the State monad in Scala. Also, a quick read of the original article won’t hurt.
There’s a big introductory part that’s dedicated to different ways of scanning a data structure. Let’s not skip this block and use it as a warmup.
Given a list of Ints, can you produce a cumulative sum of those integers? For example, if we had the list [2, 3, 5, 7, 11, 13], we want to have [2, 5, 10, 17, 28, 41].
There are actually many different ways to write this function. Depending on your taste on imperative programming, you can choose anywhere between highly imperative ST-based destructive updates and an idiomatic functional style.
This should be simple, since there’s scanLeft in the Scala standard library:
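A minimal sketch using only the standard library (the object and method names are mine, not taken from the original post). Note that scanLeft also emits the initial accumulator, so the leading 0 has to be dropped:

```scala
object CumulativeLeft {
  // scanLeft yields the seed first (0, 2, 5, ...), so drop it with tail.
  def cumulative(xs: List[Int]): List[Int] =
    xs.scanLeft(0)(_ + _).tail

  val example: List[Int] = cumulative(List(2, 3, 5, 7, 11, 13))
}
```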
For example, what if you want to produce a cumulative sum that is accumulated from the right? So if we had the list [2, 3, 5, 7, 11, 13], we want to have [41, 39, 36, 31, 24, 13].
Pretty much the same thing here:
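Again a standard-library sketch with names of my own choosing. scanRight puts the seed at the end of the result, so this time the last element is the one to drop:

```scala
object CumulativeRight {
  // scanRight yields (..., 24, 13, 0): the seed comes last, so drop it with init.
  def cumulativeR(xs: List[Int]): List[Int] =
    xs.scanRight(0)(_ + _).init

  val example: List[Int] = cumulativeR(List(2, 3, 5, 7, 11, 13))
}
```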
As a Haskell programmer, we have the instinct to generalize things.
Say no more. Scala developers love it too. We’ll use type classes from cats, but scalaz would work the same way:
But how might we implement such a function? Let’s consider cumulative first. What we really need is to keep track of the running sum as we traverse, and then returning the running sum as the new value. The State monad then becomes helpful.
State is a well-known concept in Scala, so we can easily follow up with our implementation of cumulative:
Some comments are needed here, I guess. “Traversing with State” is a powerful technique for going over a data structure while accumulating information along the way. Processing each element lets you modify the accumulated State “effect”. In this case we’re just accumulating the running sum (according to the provided Monoid) and using the same sum as the result value of the state calculation.
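To make “traversing with State” concrete without any dependencies, here is a rough sketch with a hand-rolled State and a List-only traverse. These are simplified stand-ins that I am assuming for illustration, not the cats types used in this post: there is no Eval, so it is not stack-safe, and Int addition takes the place of a generic Monoid.

```scala
object StateSketch {
  // Simplified stand-in for cats' State: a function S => (S, A).
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State { s => val (s1, a) = run(s); (s1, f(a)) }
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s => val (s1, a) = run(s); f(a).run(s1) }
    // Like runA in cats: keep the value, discard the final state.
    def runA(s: S): A = run(s)._2
  }

  // Left-to-right traversal of a List, one State action per element.
  def traverse[S, A, B](as: List[A])(f: A => State[S, B]): State[S, List[B]] =
    as.foldRight(State[S, List[B]](s => (s, Nil))) { (a, acc) =>
      f(a).flatMap(b => acc.map(b :: _))
    }

  // The running sum is both the new state and the emitted value.
  def cumulative(xs: List[Int]): List[Int] = {
    val action = traverse(xs) { x =>
      State[Int, Int] { sum => val next = sum + x; (next, next) }
    }
    action.runA(0)
  }
}
```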
runA is analogous to
evalState in Haskell – it evaluates the thing and returns just the result value (ignoring the accumulated state). And, since for stack safety reasons State calculations in cats are wrapped into
Eval, we have to execute it to get our
OK, with the original article’s help, let’s now try to implement cumulativeR with Reverse State.
Enter the reverse state monad
As mind-boggling as it is, let’s try to digest this definition:
The reverse state monad, on the other hand, has the same API, except that you can set the state, so that the last time you ask for it, you will get back the value you set in the future.
Oh man… Well, let’s try to at least port the provided implementation to Scala. But I’ll change two things compared to the original:
- I’ll swap the results in the signature of the runF function to be consistent with the normal State in cats, where the state is returned on the left, and the value is on the right.
- To implement cumulativeR we only need an Applicative, so I won’t try to provide a Monad instance for ReverseState at this moment.
So this is our
ReverseState applicative. We’re drowning in
Evals, but that’s the cost of stack safety: everything here is really similar to the original
State in cats, except for the
And, actually, no “clever use of laziness” is happening in ap. It seems it will show up in flatMap, but so far we’re fine without it – the cumulativeR implementation works already:
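As a dependency-free illustration of the idea (a sketch under my own assumptions, not the cats-based code from this post), the version below drops Eval entirely, so it is not stack-safe, and replaces the Applicative instance with a plain map2 method. The only difference from an ordinary State is the order in which the two actions see the state: the right-hand action runs first.

```scala
object ReverseStateSketch {
  // run takes the state arriving from the right and returns
  // (state passed on to the left, value), in the same order as cats' State.
  final case class ReverseState[S, A](run: S => (S, A)) {
    def map[B](f: A => B): ReverseState[S, B] =
      ReverseState { s => val (s1, a) = run(s); (s1, f(a)) }

    // Applicative-style combination: `that` (on the right) sees the state
    // first, and the state it produces is fed to `this` (on the left).
    def map2[B, C](that: ReverseState[S, B])(f: (A, B) => C): ReverseState[S, C] =
      ReverseState { s =>
        val (future, b) = that.run(s)
        val (past, a)   = run(future)
        (past, f(a, b))
      }
  }

  def pure[S, A](a: A): ReverseState[S, A] = ReverseState(s => (s, a))

  // List-only traverse built from map2; no flatMap is needed.
  def traverse[S, A, B](as: List[A])(f: A => ReverseState[S, B]): ReverseState[S, List[B]] =
    as.foldRight(pure[S, List[B]](Nil))((a, acc) => f(a).map2(acc)(_ :: _))

  def cumulativeR(xs: List[Int]): List[Int] = {
    val action = traverse(xs) { x =>
      ReverseState[Int, Int] { sum => val next = sum + x; (next, next) }
    }
    action.run(0)._2
  }
}
```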
We can check that its output is equivalent to the
Of course, due to laziness, a similar example in Haskell will not calculate anything until we explicitly ask for an element of the list or trigger the evaluation in some other way.
Let’s go ahead and implement the
Scan generalization as presented in the original.
Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, what if we want a running average? Or a running standard deviation? Or some entirely new thing such as the running maximum multiplied by the minimum? The only difference between all of those tasks is that the specific state transforming function (the function that was passed to
ReverseState) is different.
Since we don’t have proper universal quantification in Scala, I’ll just lift the
x into the type parameters list and name it
S (state). I could make it closer to the original using shapeless
polys, but that’s not the topic of the post.
Here, we simply unwrap a given state monad action, wrap it again in our
ReverseState, do the traversal, then unwrap it again.
I find it beautiful! And it works, although I decided not to present standard deviation and max*min scans here. The former would require a lot of math and the latter needs proper composition abstractions for
Scan, which fall outside the scope of this post.
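To make the point about swapping the state-transforming function concrete, here is a simplified, List-only sketch. It is not the Scan abstraction from this post: scanR, sumR and maxR are hypothetical names, and the traversal is a plain foldRight rather than a Traverse with ReverseState. Only the step function changes between a running sum and a running maximum from the right.

```scala
object ScanSketch {
  // Right-to-left stateful scan: `step` receives an element and the state
  // coming from the right, and returns (new state, output for this position).
  def scanR[S, A, B](xs: List[A], init: S)(step: (A, S) => (S, B)): List[B] = {
    val (_, out) = xs.foldRight((init, List.empty[B])) { case (a, (s, acc)) =>
      val (s1, b) = step(a, s)
      (s1, b :: acc)
    }
    out
  }

  def sumR(xs: List[Int]): List[Int] =
    scanR(xs, 0)((a, s) => (s + a, s + a))

  def maxR(xs: List[Int]): List[Int] =
    scanR(xs, Int.MinValue)((a, s) => (a max s, a max s))
}
```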
So that’s it! We implemented everything introduced in the original post, and we did it in Scala! Except for…
FlatMap! Where is my FlatMap ???
The true power of state lies in the ability to sequence stateful computations using bind (flatMap as we know it in Scala). But does it work for ReverseState?
In Haskell it definitely does. The laziness of the Haskell runtime allows bind to be a finite computation. Let’s take a closer look at the definition from the original article:
It’s clear that there’s a circular computational dependency between
a: each of them is calculated in terms of the other. But that is fine – as long as we operate on finite data, at some point the next “future” state won’t be needed and the Haskell runtime will evaluate only as much as required for the result to be produced.
So what about Scala?
I would be happy to be proven wrong, but after hours of thought and experiments, after trying to wrap pretty much every tiny thing in Eval, I came to the conclusion that there’s no way to implement a working flatMap for ReverseState in Scala.
Although there’s a way to encode a circular dependency in Scala, there has to be an explicit exit from the “loop”. In other words, the computation of such a circular dependency in Scala will only complete when, under some runtime condition, the dependency is gone. The reason is simple: Scala evaluates strictly by default, so the runtime can’t implicitly suspend computations that are not needed right now.
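A well-known standard-library illustration of a circular definition that still terminates is a self-referential LazyList (this assumes Scala 2.13 or later; the object name is mine). Each cell is an explicitly suspended computation, and the “exit” is simply that only a finite prefix is ever requested:

```scala
object KnotSketch {
  // `fibs` is defined in terms of itself; the tail is only computed on demand.
  lazy val fibs: LazyList[Int] =
    0 #:: 1 #:: fibs.zip(fibs.tail).map { case (a, b) => a + b }

  // Asking for 8 elements forces exactly as much of the cycle as needed.
  val firstEight: List[Int] = fibs.take(8).toList
}
```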
This restriction still allows some pretty interesting laziness tricks, like the loeb function, for example. But let’s take a look at what an implementation of flatMap for ReverseState might look like in Scala:
The circular dependency in the
result is unconditional – the next leg of calculation is created regardless of any previous results.
Eval won’t help here, because to work inside
Eval we need to sequence it with
flatMap. So we won’t even be able to construct our
Eval computation, since it would require circularly dependent
flatMap calls on
flatMap calls themselves are eager and there’s no way to avoid that.
So, depending on whether we wrap the
Eval.defer, we either get an infinite loop or a stack overflow for programs that involve
Seems like we reached the limits of laziness in Scala here.
There’s one case though where
ReverseState will work properly in Scala. It’s when your state type
S is a lazy data structure (a standard
Stream, for example).
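The sketch below shows one way this can play out. It is my own assumption-laden illustration rather than the code from this post: it uses LazyList (the Scala 2.13+ replacement for Stream), passes the incoming state by name, ties the knot with two local lazy vals, and the names RS and prepend are hypothetical. Because prepending to a LazyList does not force the tail, the “future” state is never demanded while the pair is being built.

```scala
object LazyReverseStateSketch {
  // The incoming ("future") state is by-name, so an action can use it
  // without forcing it.
  final case class RS[S, A](run: (=> S) => (S, A)) {
    def flatMap[B](f: A => RS[S, B]): RS[S, B] = RS[S, B] { s =>
      // `first` needs the state produced by `second`,
      // while `second` needs the value produced by `first`.
      lazy val first: (S, A)  = run(second._1)
      lazy val second: (S, B) = f(first._2).run(s)
      (first._1, second._2)
    }
  }

  // x #:: s defers `s`: the tail is only evaluated when someone walks the list.
  def prepend(x: Int): RS[LazyList[Int], Unit] =
    RS[LazyList[Int], Unit](s => (x #:: s, ()))

  val program: RS[LazyList[Int], Unit] =
    prepend(1).flatMap(_ => prepend(2))

  // State flows right to left: prepend(2) touches the initial state first.
  val result: List[Int] = program.run(LazyList.empty[Int])._1.toList
}
```

With a strict state such as Int, an action like s => (s + 1, ()) would force s immediately, so first and second would demand each other before either is available.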
It may seem like some random exceptional fact, but actually it’s the same case of providing the runtime with a condition to stop evaluation and break the circular computational dependency. This time it’s just less explicit and takes the form of
Thanks to Oleg Nizhnik (@odomontois) for pointing me in this direction.
In this post we found out that ReverseState is not a Monad in Scala. Again, I would really love to be proven wrong here, so if you happen to find a working instance, please ping me!
It’s not a Monad, but it’s an Applicative, which means we can still use it in some meaningful computations 🙂
As an example of such, we looked at right-to-left stateful traversals. Big thanks to Zhouyu Qian from Capital Match for his post about ReverseState in Haskell, which served as the foundation for the post you just read.
Thanks for reading!
An interactive playground with all of the presented code is available here.