Last updated on 04.08.2018
Hello all!
In this post we’re going to have some fun with a mind-breaking thing called Reverse State, and explore the limits of laziness in Scala along the way.
When I see something interesting implemented in a “foreign” programming language, I often want to port it to Scala, just out of pure curiosity about how it would look. Sometimes using a familiar language also helps me understand the presented concepts more deeply. Some time ago I did this with a great book called “Neural Networks and Deep Learning”: here are most of the exercises from the book written in Scala.
This time a totally different thing caught my eye. It was a really nice article about the Reverse State monad and its implementation in Haskell. I had never heard of it before, so implementing it in Scala seemed like an exciting exercise.
And it didn’t disappoint, although the outcome was not the one I expected 🙂 So I decided to make it a story instead of just plain code. I’ll use the technique that Eugene Yokota applied in his “Learning Scalaz”: we’ll follow the source (in Eugene’s case it was Learn You A Haskell For Great Good) piece by piece, discussing and writing the code along the way. Let’s go!
Prerequisites
To really follow along, the reader should be familiar with the Monoid and Traverse typeclasses, as well as with the State monad in Scala. A quick read of the original article won’t hurt either.
Warmup
The original article opens with a big introductory part dedicated to different ways of scanning a data structure. Let’s not skip it and use it as a warmup.
Given a list of Ints, can you produce a cumulative sum of those integers? For example, if we had the list [2, 3, 5, 7, 11, 13], we want to have [2, 5, 10, 17, 28, 41].
There are actually many different ways to write this function. Depending on your taste on imperative programming, you can choose anywhere between highly imperative ST-based destructive updates and an idiomatic functional style.
This should be simple: there’s scanLeft in the Scala standard library.
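A minimal sketch of the forward running sum with the standard library alone (the object name ScanLeftSum is just for illustration). scanLeft emits its seed as the first element, so it is dropped with tail:

```scala
object ScanLeftSum {
  val xs: List[Int] = List(2, 3, 5, 7, 11, 13)

  // scanLeft emits the seed first (List(0, 2, 5, ...)), so `tail` drops it
  val cumulative: List[Int] = xs.scanLeft(0)(_ + _).tail
  // List(2, 5, 10, 17, 28, 41)
}
```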
For example, what if you want to produce a cumulative sum that is accumulated from the right? So if we had the list [2, 3, 5, 7, 11, 13], we want to have [41, 39, 36, 31, 24, 13].
Pretty much the same thing here, only with scanRight.
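A sketch of the right-accumulated version (ScanRightSum is an illustrative name). This time scanRight emits the seed as the last element, so init drops it:

```scala
object ScanRightSum {
  val xs: List[Int] = List(2, 3, 5, 7, 11, 13)

  // scanRight emits the seed last (List(..., 13, 0)), so `init` drops it
  val cumulativeR: List[Int] = xs.scanRight(0)(_ + _).init
  // List(41, 39, 36, 31, 24, 13)
}
```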
As a Haskell programmer, we have the instinct to generalize things.
Say no more. Scala developers love it too. We’ll use type classes from cats, but scalaz would work the same way.
But how might we implement such a function? Let’s consider cumulative first. What we really need is to keep track of the running sum as we traverse, and then returning the running sum as the new value. The State monad then becomes helpful.
State is a well-known concept in Scala, so we can easily follow up with our implementation of cumulative.
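A sketch of cumulative as a traversal with cats’ State, assuming cats 2.x (the object name CumulativeState is illustrative). Each step combines the accumulator with the current element and returns the new accumulator both as the next state and as the value:

```scala
import cats.{Monoid, Traverse}
import cats.data.State
import cats.implicits._

object CumulativeState {
  // Running "sum" from the left, for any Traverse and any Monoid.
  def cumulative[F[_]: Traverse, A: Monoid](fa: F[A]): F[A] = {
    // One step: combine the accumulator with the current element and
    // return the new accumulator both as the next state and as the value.
    val step: A => State[A, A] = a =>
      State[A, A] { acc =>
        val next = acc |+| a
        (next, next) // (state, value): state on the left, as in cats
      }

    fa.traverse(step)
      .runA(Monoid[A].empty) // like Haskell's evalState: keep only the value
      .value                 // cats' State runs in Eval, so force it
  }
}
```

Since only Traverse and Monoid are required, the same function also works for, say, a Vector of Strings, where the Monoid is concatenation.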
A few comments are probably in order here. “Traversing with State” is a powerful technique for walking over a data structure while accumulating information along the way: processing each element can modify the accumulated State “effect”. In this case we’re just accumulating the running sum (combined with the provided Monoid) and using that same sum as the result value of each step.
runA is analogous to evalState in Haskell: it runs the computation and returns only the result value, discarding the final state. And since cats wraps State computations in Eval for stack safety, we have to force the resulting Eval with .value to get our value out.
OK, now, with the original article’s help, let’s try to implement cumulativeR with Reverse State.
Enter the reverse state monad
As mind-boggling as it is, let’s try to digest this definition:
The reverse state monad, on the other hand, has the same API, except that you can set the state, so that the last time you ask for it, you will get back the value you set in the future.
Oh man… Well, let’s at least try to port the provided implementation to Scala. I’ll change two things compared to the original:
- I’ll swap the results in the signature of the runF function to be consistent with the normal State in cats, where the state is returned on the left and the value on the right.
- To implement cumulativeR we only need an Applicative, so I won’t try to provide a Monad instance for ReverseState at this point.
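Here is a minimal sketch of such a ReverseState on top of cats. It is my own reconstruction, not necessarily the post’s exact code: it assumes cats 2.x and Scala 2.12/2.13, uses the plain ({ type L[A] = ... })#L type-lambda syntax so that no kind-projector plugin is needed, and keeps runF a plain function returning Eval (cats’ own IndexedStateT wraps the function itself in Eval as well). The helper names run and runA and the ReverseStateDemo object are illustrative:

```scala
import cats.{Applicative, Eval}
import cats.implicits._

// State on the left of the result tuple, value on the right, as in cats' State.
final case class ReverseState[S, A](runF: S => Eval[(S, A)]) {
  // Defer, so that deep compositions are evaluated by Eval's trampoline.
  def run(initial: S): Eval[(S, A)] = Eval.defer(runF(initial))
  // Analogue of State#runA: run and keep only the value.
  def runA(initial: S): Eval[A] = run(initial).map(_._2)
}

object ReverseState {
  implicit def reverseStateApplicative[S]: Applicative[({ type L[A] = ReverseState[S, A] })#L] =
    new Applicative[({ type L[A] = ReverseState[S, A] })#L] {
      def pure[A](a: A): ReverseState[S, A] =
        ReverseState[S, A](s => Eval.now((s, a)))

      // The only place that differs from a forward State: the incoming state
      // goes through `fa` first, and the state it produces ("the future") is
      // what `ff` sees. There is no circularity, so no laziness tricks needed.
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        ReverseState[S, B] { s =>
          fa.run(s).flatMap { case (future, a) =>
            ff.run(future).map { case (past, f) => (past, f(a)) }
          }
        }
    }
}

object ReverseStateDemo {
  // Number the elements with an Int counter used as state.
  val labelled: List[(String, Int)] =
    List("a", "b", "c")
      .traverse(s => ReverseState[Int, (String, Int)](n => Eval.now((n + 1, (s, n)))))
      .runA(0)
      .value
  // List(("a", 2), ("b", 1), ("c", 0))
}
```

The counter makes the direction visible: the rightmost element sees the initial state 0, and the numbering grows towards the left.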
So this is our ReverseState applicative. We’re drowning in Evals, but that’s the cost of stack safety: everything here is very similar to the original State in cats, except for the ap function.
And actually, no “clever use of laziness” is happening in ap. It seems it will show up in flatMap, but so far we’re fine without it: the cumulativeR implementation already works.
We can check that its output is equivalent to the scanRight-based implementation.
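A sketch of cumulativeR and that check, assuming cats 2.x and Scala 2.12/2.13. It has exactly the same shape as the State-based cumulative, only with ReverseState in place of State; a compact ReverseState applicative is included so the snippet compiles on its own. The element is combined on the left of the accumulator, so non-commutative monoids such as String concatenation come out in the right order:

```scala
import cats.{Applicative, Eval, Monoid, Traverse}
import cats.implicits._

// A compact ReverseState sketch, included so this snippet compiles on its own.
final case class ReverseState[S, A](runF: S => Eval[(S, A)]) {
  def run(s: S): Eval[(S, A)] = Eval.defer(runF(s))
  def runA(s: S): Eval[A] = run(s).map(_._2)
}

object ReverseState {
  implicit def applicative[S]: Applicative[({ type L[A] = ReverseState[S, A] })#L] =
    new Applicative[({ type L[A] = ReverseState[S, A] })#L] {
      def pure[A](a: A): ReverseState[S, A] = ReverseState[S, A](s => Eval.now((s, a)))
      // The state goes through `fa` first, then through `ff`: right to left.
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        ReverseState[S, B] { s =>
          fa.run(s).flatMap { case (future, a) =>
            ff.run(future).map { case (past, f) => (past, f(a)) }
          }
        }
    }
}

object CumulativeR {
  // The accumulator arrives from the right, so each element goes on its left.
  def cumulativeR[F[_]: Traverse, A: Monoid](fa: F[A]): F[A] = {
    val step: A => ReverseState[A, A] = a =>
      ReverseState[A, A] { acc =>
        val next = a |+| acc
        Eval.now((next, next))
      }
    fa.traverse(step).runA(Monoid[A].empty).value
  }

  val xs: List[Int] = List(2, 3, 5, 7, 11, 13)
  // Compare with the standard-library version; `init` drops scanRight's seed.
  val sameAsScanRight: Boolean = cumulativeR(xs) == xs.scanRight(0)(_ + _).init
}
```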
Of course, due to laziness, a similar example in Haskell would not compute anything until we explicitly ask for an element of the list or otherwise force evaluation.
Let’s go ahead and implement the Scan generalization as presented in the original.
Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, what if we want a running average? Or a running standard deviation? Or some entirely new thing such as the running maximum multiplied by the minimum? The only difference between all of those tasks is that the specific state transforming function (the function that was passed to ReverseState) is different.
Since we don’t have proper universal quantification in Scala, I’ll just lift the x into the type parameter list and name it S (for state). I could make it closer to the original using shapeless Poly functions, but that’s not the topic of this post.
Here, we simply unwrap a given state monad action, wrap it again in our ReverseState, do the traversal, then unwrap it again.
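A sketch of that generalization, assuming cats 2.x and Scala 2.12/2.13. The Scan trait, its left and right instances and the RunningAverage example are my own illustrative names, and S plays the role of the original’s x. The right-to-left instance does exactly what is described above: it unwraps each State action with run, rewraps it as a ReverseState, traverses, and runs the result:

```scala
import cats.{Applicative, Eval, Traverse}
import cats.data.State
import cats.implicits._

// A compact ReverseState sketch, included so this snippet compiles on its own.
final case class ReverseState[S, A](runF: S => Eval[(S, A)]) {
  def run(s: S): Eval[(S, A)] = Eval.defer(runF(s))
  def runA(s: S): Eval[A] = run(s).map(_._2)
}

object ReverseState {
  implicit def applicative[S]: Applicative[({ type L[A] = ReverseState[S, A] })#L] =
    new Applicative[({ type L[A] = ReverseState[S, A] })#L] {
      def pure[A](a: A): ReverseState[S, A] = ReverseState[S, A](s => Eval.now((s, a)))
      // The state goes through `fa` first, then through `ff`: right to left.
      def ap[A, B](ff: ReverseState[S, A => B])(fa: ReverseState[S, A]): ReverseState[S, B] =
        ReverseState[S, B] { s =>
          fa.run(s).flatMap { case (future, a) =>
            ff.run(future).map { case (past, f) => (past, f(a)) }
          }
        }
    }
}

// Hypothetical Scan: the state type S is an ordinary type parameter here,
// not a universally quantified one.
trait Scan {
  def apply[F[_]: Traverse, S, A, B](fa: F[A], init: S)(f: A => State[S, B]): F[B]
}

object Scan {
  // Left to right: just traverse with cats' State.
  val left: Scan = new Scan {
    def apply[F[_]: Traverse, S, A, B](fa: F[A], init: S)(f: A => State[S, B]): F[B] =
      fa.traverse(f).runA(init).value
  }

  // Right to left: unwrap each State action, rewrap it as a ReverseState,
  // traverse, then run the result.
  val right: Scan = new Scan {
    def apply[F[_]: Traverse, S, A, B](fa: F[A], init: S)(f: A => State[S, B]): F[B] =
      fa.traverse(a => ReverseState[S, B](s => f(a).run(s))).runA(init).value
  }
}

object RunningAverage {
  // State: (sum, count) of the elements seen so far; value: their average.
  def step(x: Double): State[(Double, Int), Double] =
    State[(Double, Int), Double] { case (sum, count) =>
      val next = (sum + x, count + 1)
      (next, next._1 / next._2)
    }

  val xs: List[Double] = List(1.0, 2.0, 3.0)
  val fromLeft: List[Double] = Scan.left(xs, (0.0, 0))(step)   // List(1.0, 1.5, 2.0)
  val fromRight: List[Double] = Scan.right(xs, (0.0, 0))(step) // List(2.0, 2.5, 3.0)
}
```

A running average needs only a different step function; the traversal code is the same for both directions.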
I find it beautiful! And it works, although I decided not to present the standard deviation and max*min scans here. The former would require a lot of math, and the latter needs proper composition abstractions for Scan, which are out of the scope of this post.
So that’s it! We implemented everything introduced in the original post, and we did it in Scala! Except for…
FlatMap! Where is my FlatMap???
The true power of state lies in the ability to sequence stateful computations using bind (or flatMap, as we know it in Scala). But does it work for ReverseState?
In Haskell it definitely does. The laziness of the Haskell runtime allows bind to be a finite computation. Let’s take a closer look at the definition of bind from the original article.
It’s clear that there’s a circular computational dependency between future and a: each of them is calculated in terms of the other. But that is fine: as long as we operate on finite data, at some point the next “future” state won’t be needed, and the Haskell runtime will evaluate only as much as is required to produce the result.
So what about Scala?
I would be happy to be proven wrong, but after hours of thought and experiments, and after trying to wrap pretty much every tiny thing in Eval, I came to the conclusion that there’s no way to implement flatMap for ReverseState in Scala.
Although there’s a way to encode a circular dependency in Scala, there has to be an explicit exit from the “loop”. In other words, the computation of such a circular dependency in Scala will only complete if, under some runtime condition, the dependency goes away. The reason is simple: the JVM runtime is strict, so it can’t suspend computations that are not needed right now.
This restriction still allows some pretty interesting laziness tricks, like the loeb function, for example. But let’s take a look at how an implementation of flatMap for ReverseState might look in Scala.
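To make the circularity concrete, here is a stripped-down, hypothetical ReverseState without any Eval, with flatMap transliterated from the lazy Haskell bind using two mutually recursive local lazy vals (Scala 2.12/2.13 assumed). It compiles, and building a program with it is fine; running one, however, recurses until the JVM throws a StackOverflowError, because forcing either lazy val forces the other and nothing ever breaks the cycle:

```scala
// Hypothetical, stripped-down ReverseState without Eval, only to show the knot.
final case class ReverseState[S, A](run: S => (S, A)) {
  // The first action needs the state produced by the continuation ("future"),
  // while the continuation needs the value produced by the first action.
  def flatMap[B](f: A => ReverseState[S, B]): ReverseState[S, B] =
    ReverseState[S, B] { s =>
      lazy val past: (S, A) = run(future._1)       // needs `future`...
      lazy val future: (S, B) = f(past._2).run(s) // ...which needs `past`
      (past._1, future._2)
    }
}

object FlatMapDemo {
  // Building the program is fine: nothing is evaluated yet.
  val program: ReverseState[Int, Int] =
    ReverseState[Int, Int](s => (s + 1, s))
      .flatMap(a => ReverseState[Int, Int](s => (s * 2, a)))

  // Running it forces `past`, which forces `future`, which forces `past` again.
  val blowsUp: Boolean =
    try { program.run(0); false }
    catch { case _: StackOverflowError => true }
}
```

Local lazy vals only delay evaluation; they cannot skip it. The tuple returned by the outer function needs both halves, and each half needs the other one first.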
The circular dependency in the result is unconditional: the next leg of the calculation is created regardless of any previous results.
Eval won’t help here, because to work inside Eval we need to sequence it with flatMap. So we won’t even be able to construct our Eval computation, since it would require circularly dependent flatMap calls on Eval. The flatMap calls themselves are eager, and there’s no way to avoid that.
So, depending on whether we wrap the result in Eval.defer, we get either an infinite loop or a stack overflow for programs that involve flatMap-ing ReverseState.
Seems like we reached the limits of laziness in Scala here.
Update
There’s one case, though, where flatMap for ReverseState will work properly in Scala: when your state type S is a lazy data structure (a standard Stream, for example).
It may seem like a random exception, but it’s actually the same case of giving the runtime a condition to stop evaluation and break the circular computational dependency. This time it’s just less explicit and takes the form of Stream’s laziness.
Thanks to Oleg Nizhnik (@odomontois) for pointing me in this direction.
Conclusion
In this post we found out that ReverseState is not a Monad in Scala. Again, I would really love to be proven wrong here, so if you happen to find a working instance, please ping me!
It’s not a Monad, but it’s an Applicative, which means we can still use it in some meaningful computations 🙂
As an example, we looked at right-to-left stateful traversals. Big thanks to Zhouyu Qian from Capital Match for his post about ReverseState in Haskell, which served as the foundation for the post you just read.
Thanks for reading!
Code
An interactive playground with all of the presented code is available here.