
05 December 2013

`runState` for a scalaz-stream Process

I was preparing to post this on the scalaz mailing-list but I thought that a short blog post could serve as a reference for other people as well. The following assumes that you have a good knowledge of Scalaz (at least of what's covered in my "Essence of the Iterator Pattern" post) and some familiarity with the scalaz-stream library.

My use case

What I want to do is very common, just process a bunch of files! More precisely I want to (this is slightly simplified):

  1. read some pipe delimited files

  2. validate that the files have the proper internal structure:
    one(HEADER marker)
    one(column names)
    many(lines of pipe delimited values)
    one(TRAILER marker with total number of lines since the header)

  3. output only the lines which are not markers to another file

Scalaz stream

The excellent chapter 15 of Functional Programming in Scala highlights some of the potential problems with processing files:

  • you need to make sure you are closing resources properly even in the face of exceptions
  • you want to be able to easily compose small processing functions together instead of having a gigantic loop and a bunch of variables
  • you want to control the amount of data that is in memory at any moment in time

Based on the ideas of the book, Paul Chiusano created scalaz-stream, a library providing lots of combinators for this kind of input/output streaming operation (and more!).

A state machine for the job

My starting point for addressing our requirements is to devise a State object representing both the expected file structure and the fact that some lines need to be filtered out. First of all, I need to model the kind of lines I'm expecting when reading the file:

sealed trait LineState

case object ExpectHeader                            extends LineState
case object ExpectHeaderColumns                     extends LineState
case class  ExpectLineOrTrailer(lineCount: Int = 0) extends LineState

As you can see ExpectLineOrTrailer contains a counter to keep track of the number of lines seen so far.
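The pattern matches in the next snippet use HeaderLine and TrailerLine extractors which aren't shown in this post. Here is a minimal sketch of what they could look like; the exact line formats are an assumption based on the sample file further down (HEADER|... and TRAILER|count):

```scala
// Hypothetical extractors for the marker lines (not part of the post's code);
// the HEADER|... and TRAILER|count formats are assumed from the sample file
object HeaderLine {
  def unapply(line: String): Option[String] =
    if (line.startsWith("HEADER|")) Some(line.stripPrefix("HEADER|")) else None
}

object TrailerLine {
  // extracts the expected number of data lines from a TRAILER|count line
  def unapply(line: String): Option[Int] =
    if (line.startsWith("TRAILER|")) line.stripPrefix("TRAILER|").toIntOption else None
}
```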

Then I need a method (referred to as the State function below) to update this state when reading a new line:

def lineState(line: String): State[Throwable \/ LineState, Option[String]] =
  State { state: Throwable \/ LineState =>
    def t(message: String) = new Exception(message).left

    (state, line) match {
      case (\/-(ExpectHeader), HeaderLine(_))           =>
        (ExpectHeaderColumns.right, None)
      case (\/-(ExpectHeaderColumns), _)                =>
        (ExpectLineOrTrailer(0).right, None)
      case (\/-(ExpectHeader), _)                       =>
        (t("expecting a header"), None)
      case (\/-(ExpectLineOrTrailer(n)), HeaderLine(_)) =>
        (t("expecting a line or a trailer"), None)
      case (\/-(ExpectLineOrTrailer(n)), TrailerLine(e)) =>
        if (n == e) (ExpectHeader.right, None)
        else        (t(s"wrong number of lines, expecting $e, got $n"), None)
      case (\/-(ExpectLineOrTrailer(n)), _)             =>
        (ExpectLineOrTrailer(n + 1).right, Some(line))
      case (-\/(e), _)                                  =>
        (state, None)
    }
  }

The S type parameter (in the State[S, A] type) used to keep track of the "state" is Throwable \/ LineState. I'm using the "Left" part of the disjunction to represent processing errors, and the error type itself is a Throwable. Originally I was using an arbitrary type E, but we'll see further down why I had to use exceptions. The value type A I extract from State[S, A] is going to be Option[String], in order to output None when I encounter a marker line.
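To experiment with the transition logic in isolation, here is a scalaz-free sketch of the same function using only the standard library: Either stands in for \/ (with Left carrying a plain error message instead of a Throwable) and an ordinary function replaces State. This is just for illustration, not the code used with runState:

```scala
// scalaz-free sketch of the same transitions: Either stands in for \/
// (Left = processing error) and a plain function replaces State
sealed trait LineState
case object ExpectHeader                            extends LineState
case object ExpectHeaderColumns                     extends LineState
case class  ExpectLineOrTrailer(lineCount: Int = 0) extends LineState

def step(state: Either[String, LineState], line: String): (Either[String, LineState], Option[String]) =
  (state, line) match {
    case (Right(ExpectHeader), l) if l.startsWith("HEADER|")            => (Right(ExpectHeaderColumns), None)
    case (Right(ExpectHeaderColumns), _)                                => (Right(ExpectLineOrTrailer(0)), None)
    case (Right(ExpectHeader), _)                                       => (Left("expecting a header"), None)
    case (Right(ExpectLineOrTrailer(_)), l) if l.startsWith("HEADER|")  => (Left("expecting a line or a trailer"), None)
    case (Right(ExpectLineOrTrailer(n)), l) if l.startsWith("TRAILER|") =>
      val expected = l.stripPrefix("TRAILER|").toInt
      if (n == expected) (Right(ExpectHeader), None)
      else               (Left(s"wrong number of lines, expecting $expected, got $n"), None)
    case (Right(ExpectLineOrTrailer(n)), l)                             => (Right(ExpectLineOrTrailer(n + 1)), Some(l))
    case (Left(_), _)                                                   => (state, None)
  }

// folding the function over a small valid block keeps only the data lines
val sample = List("HEADER|file", "header1|header2", "val11|val12", "val21|val22", "TRAILER|2")
val (finalState, kept) = sample.foldLeft((Right(ExpectHeader): Either[String, LineState], Vector.empty[String])) {
  case ((s, acc), line) =>
    val (s2, out) = step(s, line)
    (s2, acc ++ out)
}
// finalState == Right(ExpectHeader), kept == Vector("val11|val12", "val21|val22")
```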

This is all pretty good, functional and testable. But how can I use this state machine with a scalaz-stream Process?

runState

After much head scratching and a little help from the mailing-list (thanks Pavel!) I realized that I had to write a new driver for a Process. Something which would understand what to do with a State. Here is what I came up with:

def runState[F[_], O, S, E <: Throwable, A](p: Process[F, O])
                                           (f: O => State[E \/ S, Option[A]], initial: S)
                                           (implicit m: Monad[F], c: Catchable[F]) = {

  def go(cur: Process[F, O], init: S): F[Process[F, A]] = {
    cur match {
      case Halt(End) => m.point(Halt(End))
      case Halt(e)   => m.point(Halt(e))

      case Emit(h: Seq[O], t: Process[F, O]) => {
        println("emitting lines here!")
        val state = h.toList.traverseS(f)
        val (newState, result) = state.run(init.right)
        newState.fold (
          l => m.point(fail(l)),
          r => go(t, r).map(emitAll(result.toSeq.flatten) ++ _)
        )
      }

      case Await(req, recv, fb: Process[F, O], cl: Process[F, O]) =>
        m.bind(c.attempt(req.asInstanceOf[F[Any]])) { _.fold(
          { case End => go(fb, init)
            case e   => go(cl.causedBy(e), init) },
          o => go(recv.asInstanceOf[Any => Process[F, O]](o), init)) }
    }
  }
  go(p, initial)
}

This deserves some comments :-)

The idea is to recursively analyse what kind of Process we're currently dealing with:

  1. if this is a Halt(End) we've terminated processing with no errors. We then return an empty Seq() in the context of F (hence the m.point operation). F is the monad that provides us input values so we can think of all the computations happening here as happening inside F (probably a scalaz.concurrent.Task when reading file lines)

  2. if this is a Halt(error) we use the Catchable instance for F to instruct the input process what to do in the case of an error (probably close the file, clean up resources,...)

  3. if this is an Emit(values, rest) we traverseS the list of values in memory with our State function and we use the initial value to get: 1. the state at the end of the traversal, 2. all the values returned by our State at each step of its execution. Note that the traversal will happen on all the values in memory, there won't be any short-circuiting if the State indicates an error. Also, this is important, the traverseS method is not trampolined. This means that we will get StackOverflow exceptions if the "chunks" that we are processing are too big. On the other hand we will avoid trampolining on each line so we should get good performances. If there was an error we stop all processing and return the error otherwise we emit all the values collected by the State appended to a recursive call to go

  4. if this is an Await Process we attempt to read input values, with c.attempt, and use the recv function to process them. We can do that "inside the F monad" by using the bind (or flatMap) method. The resulting Process is sent to go in order to be processed with the State function

Note what we do in case 3. when the new state is an exception.left: we create a Process.fail process with the exception. This is why I used a Throwable to represent errors in the State function.
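The shape of that driving loop, minus the Process/Task machinery, can be sketched with plain collections (the names here are hypothetical, for illustration only): each chunk plays the role of an Emit block, the fold over a chunk plays the role of the traverseS call, and a Left state stops the recursion between chunks, not inside one:

```scala
// Illustrative sketch only: chunks stand in for Emit blocks, the fold for
// traverseS, and Left for Process.fail; runChunks is a hypothetical name
def runChunks[S, O](chunks: List[List[String]], init: Either[String, S])
                   (step: (Either[String, S], String) => (Either[String, S], Option[O])): Either[String, List[O]] =
  chunks match {
    case Nil => Right(Nil)
    case chunk :: rest =>
      // the whole chunk is traversed even if an error occurs mid-chunk,
      // mirroring the non-short-circuiting traverseS described above
      val (state, out) = chunk.foldLeft((init, Vector.empty[O])) {
        case ((s, acc), line) =>
          val (s2, o) = step(s, line)
          (s2, acc ++ o)
      }
      state match {
        case Left(e)  => Left(e)  // like creating a Process.fail: stop everything
        case Right(s) => runChunks(rest, Right(s))(step).map(out.toList ++ _)
      }
  }
```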

Now let's see how to use this new "driver".

Let's use it

First of all, we create a test file:

import scalaz.stream._
import Process._

val lines = """|HEADER|file
               |header1|header2
               |val11|val12
               |val21|val22
               |val21|val22
               |TRAILER|3""".stripMargin

// save the lines above 100 times in a file
fill(100)(lines).intersperse("\n").pipe(process1.utf8Encode)
  .to(io.fileChunkW("target/file.dat")).run.run

Then we read the file back, buffering 50 lines at a time to control our memory usage:

val lines = io.linesR("target/file.dat").buffer(50)

We're now ready to run the state function:

// this task processes the lines with our State function
// the initial State is `ExpectHeader` because this is what we expect the first line to be
val stateTask: Task[Process[Task, String]] = runState(lines)(lineState, ExpectHeader)

// this one outputs the lines to a result file
// separating each line with a new line and encoding it in UTF-8
val outputTask: Task[Unit] = stateTask.flatMap(_.intersperse("\n").pipe(process1.utf8Encode)
                                      .to(io.fileChunkW("target/result.dat")).run)

// if the processing throws an Exception it will be retrieved here
val result: Throwable \/ Unit = outputTask.attemptRun

When we finally run the Task, the result is either ().right if we were able to read, process, and write back to disk, or exception.left if there was any error along the way, including when checking whether the file has a valid structure.

The really cool thing about all of this is that we can now precisely control the amount of memory consumed during our processing by using the buffer method. In the example above we buffer 50 lines at a time and then process them in memory using traverseS. This is why I left a println statement in the runState method: I wanted to see "with my own eyes" how buffering was working. We could probably load more lines, but the trade-off is that the stack consumed by traverseS will grow and we might face StackOverflow exceptions.

I haven't done any benchmarks yet, but I can imagine lots of different ways to optimise the whole thing for our use case.

try { blog } finally { closing remarks }

I'm only scratching the surface of the scalaz-stream library and there is still a big possibility that I completely misunderstood something obvious!

First, it is important to say that you might not need to implement the runState method if you don't have complex validation requirements. There are 2 methods, chunkBy and chunkBy2, which allow you to create "chunks" of lines based on a given delimiter line (for chunkBy) or pair of delimiter lines (for chunkBy2) naturally serving as "block" delimiters in the file being read (for example a "HEADER" followed by a "TRAILER" in my file).
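To make that chunking idea concrete without depending on the scalaz-stream API, here is what splitting lines into HEADER-delimited blocks looks like with plain collections; chunkBy/chunkBy2 do this on a streaming Process instead of an in-memory list, and the helper name below is hypothetical:

```scala
// In-memory illustration of delimiter-based chunking: a new block starts
// at each HEADER line (hypothetical helper, not the scalaz-stream API)
def blocks(lines: List[String]): List[List[String]] =
  lines.foldLeft(List.empty[List[String]]) {
    case (acc, line) if line.startsWith("HEADER|") => List(line) :: acc
    case (Nil, line)                               => List(List(line))  // lines before any header
    case (current :: done, line)                   => (line :: current) :: done
  }.map(_.reverse).reverse
```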

Second, it is not yet obvious to me whether I should use ++ or fby when emitting the state-processed lines followed by "the rest" (in step 3, when doing emitAll(result.toSeq.flatten) ++ _). The difference has to do with error/termination management (the fallback process of Await) and I'm still unclear on how/when to use it.

Finally, I would say that the scalaz-stream library is intriguing in terms of types. A process is Process[F[_], O] where O is the type of the output, but where is the type of the input? It is actually in the Await[F[_], A, O] constructor, as a forall type. That's not all: in Await you have the type of the request, F[A], and a function to process elements of type A, recv: A => Process[F, O], but no way to extract or map the value A from the request to pass it to the recv method! The only way to do that is to provide an additional constraint to the "driver method", for example by saying that there is an implicit Monad[F] somewhere. This is the first time I've seen a design where we build structures first and only give them properties when we want to use them. Very unusual.

I hope this can help other people exploring the library and, who knows, some of this might end up being part of it. Let's see what Paul and others think...
