Stream Fuse Carefully

The UW reading group in programming languages recently read “Exploiting vector instructions with generalized stream fusion”, from ICFP 2013 [1]. The paper shows a form of stream fusion which can use special vector instructions in the CPU (as well as bulk-copy operations such as memcpy). James and I noticed an interesting gotcha in using stream fusion, which we’re writing about here.

What is stream fusion?

Stream fusion is a way to avoid intermediate data structures when writing code in a functional style. Suppose we have the code

foo xs = sum $ filter even $ map floor $ xs

This code will be slow, because the map and filter invocations each allocate a new intermediate list. Instead, it would be better to combine the iteration over the intermediate lists into one giant recursive function:

fooᵣ xs = fooᵣ' xs 0
  where
    fooᵣ' (x:xs) s =
      let r = floor x in
      if even r then fooᵣ' xs (s + r) else fooᵣ' xs s
    fooᵣ' [] s = s

This code is faster and avoids allocation but is ugly to write. The functional version used many small functions together to achieve its results; our recursive version had to reimplement all of them.

Stream fusion achieves the same result as our recursive code by using clever versions of map, filter, and sum. The map, filter, and sum operations on lists become the mapₛ, filterₛ, and sumₛ operations on streams. A stream is a stateful function which produces either elements of the stream or a special “skip” instruction. The skip instruction allows functions like filterₛ to skip certain elements. The Step type is parameterized over the element type α and also the state type s used by the stream generator.

data Step α s = Yield α s | Skip s | Done

A stream, then, internally has some state type (∃ s), and contains a starting state s and a function s → Step α s which produces steps in the stream.

data Stream α = ∃ s · Stream {next: s → Step α s, start: s}

We can turn lists into streams and back:

stream :: [α] → Stream α
stream l = Stream { next: iter, start: l }
  where
    iter (x:xs) = Yield x xs
    iter [] = Done

unstream :: Stream α → [α]
unstream (Stream { next: iter, start: s₀ }) = go iter s₀
  where
    go iter s =
      case iter s of
        Yield elt s' → elt : go iter s'
        Skip s' → go iter s'
        Done → []
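
For readers who want to load these definitions into GHC, here is a minimal sketch in plain Haskell syntax. It makes a few assumptions that are not part of the notation above: the existential ∃ s is spelled with the ExistentialQuantification extension, the record fields next and start become positional constructor arguments, and main is only a small round-trip check.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- One step of a stream: an element plus the next state, a skip, or the end.
data Step a s = Yield a s | Skip s | Done

-- The state type s is hidden from users of the stream.
-- Fields are positional here: first the step function, then the start state.
data Stream a = forall s. Stream (s -> Step a s) s

-- A list is its own state: each step peels off the head.
stream :: [a] -> Stream a
stream = Stream next
  where
    next (x:xs) = Yield x xs
    next []     = Done

-- Run the step function until Done, dropping any skips along the way.
unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

main :: IO ()
main = print (unstream (stream [1, 2, 3 :: Int]))  -- prints [1,2,3]
```

Note that unstream ∘ stream really is the identity on lists; it is the other composition, stream ∘ unstream, that the rest of this post is about.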

As you can imagine, we can write mapₛ, filterₛ, and sumₛ as functions on Stream α that do the expected thing. So we can now write

map f = unstream ∘ mapₛ f ∘ stream
filter f = unstream ∘ filterₛ f ∘ stream
sum = sumₛ ∘ stream
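
The stream versions themselves are not spelled out above, so here is one way they might look. This is a sketch under assumptions, not the definitions from the paper or from any library: the ASCII names mapS, filterS, and sumS stand in for mapₛ, filterₛ, and sumₛ, and the Stream constructor takes its fields positionally.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream = Stream next
  where
    next (x:xs) = Yield x xs
    next []     = Done

-- Apply f to every yielded element; skips pass through untouched.
mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Rejected elements become skips, so the step function never has to loop.
filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Consume the stream with a strict accumulator.
sumS :: Num a => Stream a -> a
sumS (Stream next s0) = go 0 s0
  where
    go acc s = case next s of
      Yield x s' -> let acc' = acc + x in acc' `seq` go acc' s'
      Skip s'    -> go acc s'
      Done       -> acc

-- The running example, written directly on streams.
foo :: [Double] -> Integer
foo = sumS . filterS even . mapS floor . stream

main :: IO ()
main = print (foo [1.5, 2.2, 3.9, 4.0])  -- floors are 1,2,3,4; evens sum to 6
```

Notice that both mapS and filterS answer a Skip with a Skip and leave their state otherwise unchanged. That property matters later.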

The benefit of this encoding is that our original functional code becomes

foo xs = sumₛ ∘ stream $ unstream ∘ filterₛ even ∘ stream
	 $ unstream ∘ mapₛ floor ∘ stream $ xs

Now, in this code the only intermediate data structures are produced by the stream ∘ unstream blocks. Since these calls don’t do anything to our list elements, we can ignore them, using GHC’s RULES directive:

"STREAM stream/unstream fusion" forall s.
    stream (unstream s) = s
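
In a source file the rule lives inside a RULES pragma. The sketch below shows one plausible arrangement, with some assumptions of mine: ASCII names (mapS, mapL) replace the subscripted ones, and the NOINLINE [1] phase annotations are there to keep stream and unstream from being inlined before the rule has had a chance to match. Whether the rule actually fires depends on the optimization level, which is why the example only uses a function whose result is the same either way.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream = Stream next
  where
    next (x:xs) = Yield x xs
    next []     = Done
{-# NOINLINE [1] stream #-}

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []
{-# NOINLINE [1] unstream #-}

-- The rewrite rule from the text, in GHC's pragma syntax.
{-# RULES
"STREAM stream/unstream fusion" forall s.
    stream (unstream s) = s
  #-}

mapS :: (a -> b) -> Stream a -> Stream b
mapS f (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' -> Yield (f x) s'
      Skip s'    -> Skip s'
      Done       -> Done

-- List-level map; INLINE exposes the stream/unstream pair to the rule.
mapL :: (a -> b) -> [a] -> [b]
mapL f = unstream . mapS f . stream
{-# INLINE mapL #-}

main :: IO ()
main = print (mapL (+ 1) (mapL (* 2) [1, 2, 3 :: Int]))  -- prints [3,5,7]
```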

GHC will then eliminate the stream ∘ unstream calls and then inline the mapₛ, filterₛ, and sumₛ blocks. After optimizing away the Step data structures, the same tight code as fooᵣ will be produced. And all of this happens automatically, so our code is clean even though it runs quickly.

What’s the problem?

As implemented, stream fusion is a great optimization which achieves great performance. There’s no problem with it. But there is a subtle gotcha in implementing it. You see, stream ∘ unstream is not exactly equal to the identity.

The unstream function skips Skip steps, because those don’t carry elements of the stream. The stream function, since it operates on an intermediate list, cannot restore them. But when stream ∘ unstream is replaced with the identity, those skip steps hang around.

This isn’t a problem for map or filter, because both of those skip when their input stream skips, without changing their internal state in any way. But some care needs to be taken to make sure that every stream function has this behavior.

A stream function might depend on the presence of skips by, for example, replacing them with previous elements of the stream. Of course, this is a constructed example; you wouldn’t accidentally write this code. But it demonstrates the principle:

fillₛ (Stream { next: iter, start: s₀ }) = Stream { next: go, start: (s₀, Nothing) }
  where
    go (s, prev) =
      case iter s of
        Yield x s' → Yield x (s', Just x)
        Skip s' →
          case prev of
            Nothing → Skip (s', prev)
            Just x → Yield x (s', prev)
        Done → Done

If the input stream yields elements of type α, our state has type (s, Maybe α). The second part of the state is the most recent element of the input stream; it starts as Nothing (because there isn’t yet a previous element), and is set to Just x every time the input stream produces an x. When the input stream skips an element, we output the previous element if one exists.

For example,

unstream $ fillₛ $ filterₛ even $ stream [1..6]
⇒ [2, 2, 4, 4, 6]

Filtering away the odd elements results in a stream containing

skip ; yield 2 ; skip ; yield 4 ; skip ; yield 6

Then the fill function fills in the later two skips.

However, if we go ahead and define

fill = unstream ∘ fillₛ ∘ stream

we run into trouble. Without the rewrite rule,

fill $ filter even $ [1..6] ⇒ [2, 4, 6]

because the intermediate stream ∘ unstream removes all of the skips from the stream. With the rewrite rule, we instead see

fill $ filter even $ [1..6] ⇒ [2, 2, 4, 4, 6]

So the use of the rewrite rule changes the meaning of our program. That’s a problem, because rewrite rules can be turned on or off by compiler flags or optimization level.
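
The two behaviors can be reproduced without relying on compiler flags by applying the rewrite by hand. In the sketch below, unfused keeps the intermediate stream ∘ unstream and fused deletes it, which is what the rule would do. The names unfused, fused, fillS, and filterS are mine, and the Stream constructor takes positional fields.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream = Stream next
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- Replace each skip with the most recently yielded element, if there is one.
fillS :: Stream a -> Stream a
fillS (Stream next s0) = Stream next' (s0, Nothing)
  where
    next' (s, prev) = case next s of
      Yield x s' -> Yield x (s', Just x)
      Skip s'    -> case prev of
        Nothing -> Skip (s', prev)
        Just x  -> Yield x (s', prev)
      Done       -> Done

-- fill (filter even xs) with the intermediate list left in place.
unfused :: [Int] -> [Int]
unfused = unstream . fillS . stream . unstream . filterS even . stream

-- The same pipeline after deleting stream . unstream, as the rule would.
fused :: [Int] -> [Int]
fused = unstream . fillS . filterS even . stream

main :: IO ()
main = do
  print (unfused [1 .. 6])  -- prints [2,4,6]
  print (fused [1 .. 6])    -- prints [2,2,4,4,6]
```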

But would this actually show up?

Of course, fillₛ is not a function one would normally define. Yet there are reasonable functions where a typo might lead to different behavior with and without the stream ∘ unstream rewrite rule.

Consider the enumerate function, which we’d normally write as

enumerate :: [α] → [(ℤ, α)]
enumerate xs = zipWith (,) [0..] xs

enumerate pairs each element of a list with its index in the list; it’s often handy. We can write this as a stream function:

enumerateₛ (Stream { next: iter, start: s₀ }) =
  Stream { next: go, start: (s₀, 0) }
  where
    go (s, n) =
      case iter s of
        Yield x s' → Yield (n, x) (s', n+1)
        Skip s' → Skip (s', n+1) -- should be (s', n)
        Done → Done
enumerate = unstream ∘ enumerateₛ ∘ stream

There’s a bug in the Skip case, where we increment the index even though we have not seen a new element. Thanks to this bug, we have

enumerate $ filter even $ [1..7] ⇒ [(0, 2), (1, 4), (2, 6)]

without the rewrite rule, but as soon as we enable optimizations,

enumerate $ filter even $ [1..7] ⇒ [(1, 2), (3, 4), (5, 6)]

Note that the indices now track the index in the original list, not in the intermediate list that has been optimized away.

Of course, this pitfall has been carefully avoided by the authors of the Haskell stream fusion libraries. But if you’re writing your own, watch out: this bug is easy to make and easy to miss. I’d write a QuickCheck property checking that each of your stream functions gives the same result whether or not the stream (unstream s) = s rewrite has been applied to its input.
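
Such a check might look like the sketch below. To stay free of dependencies it is a plain predicate on a fixed input rather than an actual QuickCheck property; with QuickCheck one would presumably generate the input list and the filter predicate at random. The names respectsSkips and enumerateWith are hypothetical, and enumerateWith takes the increment used in the Skip case as a parameter so that the buggy (1) and fixed (0) versions share one definition. The check only terminates on finite streams.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Step a s = Yield a s | Skip s | Done
data Stream a = forall s. Stream (s -> Step a s) s

stream :: [a] -> Stream a
stream = Stream next
  where
    next (x:xs) = Yield x xs
    next []     = Done

unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0
  where
    go s = case next s of
      Yield x s' -> x : go s'
      Skip s'    -> go s'
      Done       -> []

filterS :: (a -> Bool) -> Stream a -> Stream a
filterS p (Stream next s0) = Stream next' s0
  where
    next' s = case next s of
      Yield x s' | p x       -> Yield x s'
                 | otherwise -> Skip s'
      Skip s'    -> Skip s'
      Done       -> Done

-- enumerate on streams; skipInc is added to the index on every Skip.
-- skipInc = 0 is the correct version, skipInc = 1 reproduces the bug.
enumerateWith :: Int -> Stream a -> Stream (Int, a)
enumerateWith skipInc (Stream next s0) = Stream next' (s0, 0)
  where
    next' (s, n) = case next s of
      Yield x s' -> Yield (n, x) (s', n + 1)
      Skip s'    -> Skip (s', n + skipInc)
      Done       -> Done

-- Does f produce the same elements whether or not its input keeps its skips?
-- stream . unstream strips the skips, so this compares both sides of the rule.
respectsSkips :: Eq b => (Stream a -> Stream b) -> Stream a -> Bool
respectsSkips f s = unstream (f s) == unstream (f (stream (unstream s)))

main :: IO ()
main = do
  let input = filterS even (stream [1 .. 7 :: Int])  -- a stream with skips
  print (respectsSkips (enumerateWith 0) input)  -- prints True
  print (respectsSkips (enumerateWith 1) input)  -- prints False
```

The input has to contain skips for the check to mean anything, which is why it goes through filterS first; a stream built directly from a list would pass for any function.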


[1] By G. Mainland, R. Leshchinskiy, and S. Peyton Jones.

Share it; it's CC-BY-SA licensed.

Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author and do not necessarily reflect the views of the National Science Foundation.