Error and state handling with monad transformers in Scala

In this post I will look at a practical example where combining the state monad and the either monad through monad transformers proves very useful.

I won't go into much theory. Instead, I will demonstrate the problem and then build up a solution step by step.

You don't have to be completely familiar with all of these concepts, as the examples are easy to follow. Here is a very brief overview:

  • State monad: Data type for modeling mutable state in a purely functional way
  • Either monad: Data type that holds one of two values, by convention a failure (Left) or a success (Right)
  • Monad transformers: Type constructors to combine multiple monads into one
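
To make the Either bullet concrete, here is a small sketch in plain Scala. It assumes Scala 2.12 or later, where Either is right-biased. The helper parsePositive is hypothetical and exists only for this illustration. A for-comprehension over Either stops at the first Left.

```scala
import scala.util.Try

object EitherDemo {
  // Hypothetical helper: Right holds the parsed number, Left an error message
  def parsePositive(s: String): Either[String, Int] =
    Try(s.toInt).toOption match {
      case Some(n) if n > 0 => Right(n)
      case Some(n)          => Left(s"not positive: $n")
      case None             => Left(s"not a number: $s")
    }

  // Either is right-biased, so the for-comprehension stops at the first Left
  def sumOf(a: String, b: String): Either[String, Int] =
    for {
      x <- parsePositive(a)
      y <- parsePositive(b)
    } yield x + y
}
```

For example, sumOf("2", "3") yields Right(5). sumOf("x", "-1") yields Left("not a number: x"), and the second argument is never parsed.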

The complete source code can be found on GitHub.

The problem

Consider a function that extracts the schema of a JSON value. There are two things to note:

  • Every JSON object schema must be given a unique Id (needed for further processing)
  • There are cases when processing can fail (e.g. with null values or empty arrays)
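
import java.util.UUID
import play.api.libs.json._ // assuming play-json, which provides Json.parse and the Js* types
// Schema, SchemaObject, SchemaArray, ClassName etc. are defined in the accompanying source code

object SchemaExtractor {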

  def fromJsRoot(value: JsValue): Schema = {
    value match {
      case JsObject(fields) =>
        fromJsObject("root", fields.toList)
      case _ =>
        ???
    }
  }

  private def fromJsObject(name: ClassName, fields: List[(String, JsValue)]): Schema = {
    val schemaFields =
      fields.map { case (fieldName, value) =>
        fromJsValue(fieldName, value)
      }

    SchemaObject(
      UUID.randomUUID,
      name,
      schemaFields)
  }

  private def fromJsValue(name: String, value: JsValue): (String, Schema) = {
    val schema = value match {
      case JsNull =>
        ???
      case JsString(_) =>
        SchemaString
      case JsNumber(_) =>
        SchemaDouble
      case JsBoolean(_) =>
        SchemaBoolean
      case JsObject(fields) =>
        fromJsObject(name, fields.toList)
      case JsArray(values) =>
        val schemas =
          values.map(v => fromJsValue(name, v)._2)
        val first = schemas.head
        if (schemas forall Schema.haveSameStructure(first)) {
          SchemaArray(first)
        } else {
          ???
        }
    }

    (name, schema)
  }
}

The function is impure. Leaving aside the error cases for now, generating the Id by calling UUID.randomUUID introduces a side effect that makes the function non-deterministic. Referential transparency demands that a function always yield the same output for a given input. This doesn't hold here.
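
The following stripped-down sketch shows the consequence. It uses a hypothetical Node type in place of SchemaObject. Once a random UUID is involved, two calls with the same argument produce unequal results. Passing the Id in explicitly restores equality.

```scala
import java.util.UUID

object ReferentialTransparency {
  // Hypothetical stand-in for SchemaObject: an Id plus a name
  final case class Node[Id](id: Id, name: String)

  // Impure: the same argument yields a different Node on every call
  def impureNode(name: String): Node[UUID] = Node(UUID.randomUUID(), name)

  // Pure: the Id is an explicit input, so equal inputs give equal outputs
  def pureNode(id: Int, name: String): Node[Int] = Node(id, name)
}
```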

One of the consequences is that testing becomes a tedious task.

test("test impure schema extractor") {
  val json =
    """{
      |  "value1": 42,
      |  "value2": { "value2.1": "", "value2.2": 5, "value2.3": { "value2.3.1": 1.0, "value2.3.2": 1.0, "value2.3.3": 1.0 } },
      |  "value3": { "value3.1": true, "value3.2": 2 },
      |  "value4": [
      |     [ { "value4.1": "", "value4.2": 123 }, { "value4.1": "", "value4.2": 345 } ],
      |     [ { "value4.1": "", "value4.2": 678 }, { "value4.1": "", "value4.2": 312 } ]
      |  ]
      |}
    """.stripMargin

  val schema = SchemaExtractor.fromJsRoot(Json.parse(json))

  schema match {
    case SchemaObject(_, name, fields) =>
      assert(name == "root")
      assert(fields.map(_._1) == List("value1", "value2", "value3", "value4"))
      // ...
    case _ => fail()
  }
}

We cannot simply compare the expected and actual results for structural equality with ==, because the generated Ids differ on every run. Instead, we have to deconstruct the result manually.

If we could make the Id generation pure, we would gain several benefits, better testability among them.

The state monad

One possible solution is to use the state monad.

The state monad is represented by a data type that wraps a function of type S => (S, A). The function takes a state of type S and yields a new state of type S together with a value of type A. In Cats, for example, this type is called State, and it provides many useful operations for composing and transforming stateful computations.
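
For intuition, here is a minimal hand-rolled State in plain Scala. It only illustrates the S => (S, A) idea. Cats' State has the same shape, but it is built on Eval and offers far more combinators. The flatMap method is what threads the updated state from one step into the next.

```scala
// Minimal hand-rolled State (illustration only, not Cats' implementation)
final case class State[S, A](run: S => (S, A)) {
  // Transform the value, leaving the state as it is
  def map[B](f: A => B): State[S, B] =
    State[S, B]((s: S) => {
      val (s1, a) = run(s)
      (s1, f(a))
    })

  // Run this step, then feed its new state into the step chosen by f
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B]((s: S) => {
      val (s1, a) = run(s)
      f(a).run(s1)
    })
}

object State {
  def get[S]: State[S, S] = State[S, S]((s: S) => (s, s))
  def modify[S](f: S => S): State[S, Unit] = State[S, Unit]((s: S) => (f(s), ()))
  def pure[S, A](a: A): State[S, A] = State[S, A]((s: S) => (s, a))
}
```

For example, `for { n <- State.get[Int]; _ <- State.modify[Int](_ + 1) } yield n` returns the current counter and increments it. Running it from 5 gives (6, 5).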

To get a better understanding of how to apply this to the original code, let's look at a simpler example: labeling the leaves of a tree structure (similar to this example). Here is the definition of the tree data type:

sealed trait Tree[+A]
case class Branch[+A](left: Tree[A], right: Tree[A]) extends Tree[A]
case class Leaf[A](a: A) extends Tree[A]

Suppose we want to transform a tree of type Tree[A] into a labeled tree of type LabeledTree[A], defined as type LabeledTree[A] = Tree[(Int, A)].

This can be done like this:

def fromTree[A](tree: Tree[A]): State[Int, LabeledTree[A]] = {
  tree match {
    case Leaf(a) =>
      for {
        state <- State.get[Int]
        _ <- State.modify[Int](s => s + 1)
      } yield Leaf(state, a)
    case Branch(left, right) =>
      for {
        l <- fromTree(left)
        r <- fromTree(right)
      } yield Branch(l, r)
  }
}

The for-comprehension is syntactic sugar for flatMap and map calls. These let us sequence the implicit state changes, so that each step runs with the state produced by the previous one.
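
To see what the for-comprehension does, here is the same fromTree with flatMap and map written out by hand. The sketch is self-contained and uses a minimal hand-rolled State instead of Cats, so it can run on its own. The type ascriptions help type inference, because this State is invariant in its value type.

```scala
// Minimal hand-rolled State (illustration only, not Cats)
final case class State[S, A](run: S => (S, A)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B]((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B]((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
}

object State {
  def get[S]: State[S, S] = State[S, S]((s: S) => (s, s))
  def modify[S](f: S => S): State[S, Unit] = State[S, Unit]((s: S) => (f(s), ()))
}

sealed trait Tree[+A]
final case class Branch[+A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](a: A) extends Tree[A]

object Labeling {
  type LabeledTree[A] = Tree[(Int, A)]

  def fromTree[A](tree: Tree[A]): State[Int, LabeledTree[A]] =
    tree match {
      case Leaf(a) =>
        // for { n <- get; _ <- modify(_ + 1) } yield Leaf((n, a)), desugared:
        State.get[Int].flatMap { n =>
          State.modify[Int](_ + 1).map(_ => (Leaf((n, a)): LabeledTree[A]))
        }
      case Branch(left, right) =>
        // for { l <- fromTree(left); r <- fromTree(right) } yield Branch(l, r), desugared:
        fromTree(left).flatMap { l =>
          fromTree(right).map(r => (Branch(l, r): LabeledTree[A]))
        }
    }
}
```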

Let's run this code:

val tree: Tree[Char] = Branch(Leaf('a'), Branch(Branch(Leaf('b'), Leaf('c')), Leaf('d')))
pprintln(tree)
pprintln(LabeledTreeExampleWithState.LabeledTree.fromTree(tree).runA(0).value)

Output:

Branch(Leaf('a'), Branch(Branch(Leaf('b'), Leaf('c')), Leaf('d')))
Branch(Leaf((0, 'a')), Branch(Branch(Leaf((1, 'b')), Leaf((2, 'c'))), Leaf((3, 'd'))))

Schema extraction with the state monad

First we need a pure function of type S => S that defines how a new Id is produced from the previous one. In our case the schema Id is of type UUID, and the next Id can be computed like this:

type SchemaId = UUID
def nextId: SchemaId => SchemaId = id => UUID.nameUUIDFromBytes(id.toString.getBytes)

In the original implementation we needed to generate UUIDs to ensure uniqueness. But the new mechanism for generating unique values can be applied to almost any other type. Do we need all the overhead of the UUID type? No. We shouldn't needlessly tie our code to this cumbersome representation of an Id when a simple Int will do:

type SchemaId = Int
def nextId: SchemaId => SchemaId = _ + 1
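
Both nextId variants are pure, so iterating them from the same seed always yields the same sequence of Ids. The sketch below checks this. The helper ids is hypothetical and only collects the first n Ids produced from a seed.

```scala
import java.util.UUID

object Ids {
  // UUID variant: derive the next Id deterministically from the previous one
  def nextUuid(id: UUID): UUID = UUID.nameUUIDFromBytes(id.toString.getBytes)

  // Int variant: the next Id is simply the successor
  def nextInt(id: Int): Int = id + 1

  // Hypothetical helper: the first n Ids produced from a seed
  def ids[A](seed: A, n: Int)(next: A => A): List[A] =
    Iterator.iterate(seed)(next).take(n).toList
}
```

The Int variant is simpler, and its Ids are easier to read in test expectations.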

Now we can change the original implementation according to the example of the labeled tree:

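import cats.data.State
import cats.data.State.{get, modify, pure} // used unqualified below
import cats.implicits._                    // brings the sequence syntax into scope
import play.api.libs.json._                // assuming play-json, as in the tests (Json.parse)

object SchemaExtractor {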

  def fromJsRoot(value: JsValue): State[SchemaId, Schema] = {
    value match {
      case JsObject(fields) =>
        fromJsField("root", fields.toList)
      case _ =>
        ???
    }
  }

  private def fromJsField(name: ClassName, fields: List[(String, JsValue)]): State[SchemaId, Schema] = {
    val schemaFieldsState =
      fields.map { case (fieldName, value) =>
        fromJsValue(fieldName, value)
      }

    for {
      schemaFields <- schemaFieldsState.sequence
      state <- get[SchemaId]
      _ <- modify(Schema.nextId)
    } yield
      SchemaObject(
        state,
        name,
        schemaFields)
  }

  private def fromJsValue(name: String, value: JsValue): State[SchemaId, (String, Schema)] = {
    val schema = value match {
      case JsNull =>
        ???
      case JsString(_) =>
        pure[SchemaId, Schema](SchemaString)
      case JsNumber(_) =>
        pure[SchemaId, Schema](SchemaDouble)
      case JsBoolean(_) =>
        pure[SchemaId, Schema](SchemaBoolean)
      case JsObject(fields) =>
        fromJsField(name, fields.toList)
      case JsArray(values) =>
        fromJsArray(name, values.toList)
    }

    schema.map((name, _))
  }

  private def fromJsArray(name: String, values: List[JsValue]): State[SchemaId, Schema] = {
    for {
      schemas <- values.map(v => fromJsValue(name, v).map(_._2)).sequence
      first = schemas.head
    } yield
      if (schemas forall Schema.haveSameStructure(first)) {
        SchemaArray(first)
      } else {
        ???
      }
  }

}

There are a few things to note here.

  • We use map to transform the value inside a State[S, A], e.g. schema.map((name, _)).
  • We use pure to create a State[S, A] from a plain value, e.g. pure[SchemaId, Schema](SchemaString).
  • We use sequence to turn a List[State[S, A]] into a State[S, List[A]], e.g. schemaFields <- schemaFieldsState.sequence.
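
The sequence operation turns a list of stateful steps into a single stateful step. The following hand-rolled sketch shows the idea for a minimal State. The steps run in list order, and each one sees the state left by the previous one. This is not Cats' implementation, which derives sequence generically from Traverse and Applicative.

```scala
// Minimal hand-rolled State (illustration only, not Cats)
final case class State[S, A](run: S => (S, A)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B]((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B]((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
}

object State {
  def get[S]: State[S, S] = State[S, S]((s: S) => (s, s))
  def modify[S](f: S => S): State[S, Unit] = State[S, Unit]((s: S) => (f(s), ()))
  def pure[S, A](a: A): State[S, A] = State[S, A]((s: S) => (s, a))

  // List[State[S, A]] => State[S, List[A]]: steps run left to right,
  // each seeing the state produced by the one before it
  def sequence[S, A](steps: List[State[S, A]]): State[S, List[A]] =
    steps.foldRight(pure[S, List[A]](Nil)) { (step, rest) =>
      for {
        a  <- step
        as <- rest
      } yield a :: as
    }
}
```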

Testing is much easier now

runA runs the State returned by the schema extractor from an initial state. It keeps only the resulting value and discards the final state. Cats' State is built on Eval to maintain stack safety, so the result is an Eval, and calling value on it yields the final result.
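
The three ways of running a state computation differ only in what they keep. Here is a tiny hand-rolled sketch that mirrors the names of Cats' run, runA and runS. Unlike Cats, it returns plain values rather than Eval.

```scala
// Hand-rolled State reduced to the run functions (illustration only);
// on Cats' State, run, runA and runS return these results wrapped in Eval
final case class State[S, A](run: S => (S, A)) {
  def runA(initial: S): A = run(initial)._2 // keep only the value
  def runS(initial: S): S = run(initial)._1 // keep only the final state
}
```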

val actual = SchemaExtractor.fromJsRoot(Json.parse(json)).runA(0).value

In the test we can now simply check for equality instead of deconstructing the complete result:

// ...
val value4 = SchemaArray(SchemaArray(SchemaObject(3, "value4", List(("value4.1", SchemaString), ("value4.2", SchemaDouble)))))
val value3 = SchemaObject(2, "value3", List(("value3.1", SchemaBoolean), ("value3.2", SchemaDouble)))
val value2 = SchemaObject(1, "value2", List(
  ("value2.1", SchemaString),
  ("value2.2", SchemaDouble),
  ("value2.3", SchemaObject(0, "value2.3", List(
    ("value2.3.1", SchemaDouble),
    ("value2.3.2", SchemaDouble),
    ("value2.3.3", SchemaDouble)))
  )
))

val expected = SchemaObject(7, "root", List(
  ("value1", SchemaDouble),
  ("value2", value2),
  ("value3", value3),
  ("value4", value4)))

assert(actual == expected)

Error handling and monad transformers

Now it's time to apply error handling with the help of Either.

However, if we simply used a type like State[SchemaId, Either[Error, Schema]], we would have to write a lot of messy, nested for-comprehensions: one level to sequence the state changes and another to unpack each Either.

Here monad transformers come to the rescue because they allow us to combine the behavior of multiple monads into one.
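
Conceptually, stacking Either inside State gives a function of type S => Either[E, (S, A)]. Each step either fails or produces a new state and a value. The hand-rolled EitherState below sketches that shape. It is not Cats' StateT, which abstracts over any monad F. Its flatMap stops at the first Left, so later steps never run.

```scala
// Hand-rolled sketch of State stacked over Either: S => Either[E, (S, A)].
// Illustration only; Cats' StateT abstracts over any monad F instead of Either.
final case class EitherState[E, S, A](run: S => Either[E, (S, A)]) {
  def map[B](f: A => B): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).map { case (s1, a) => (s1, f(a)) })

  // A Left from this step skips f entirely: the computation short-circuits
  def flatMap[B](f: A => EitherState[E, S, B]): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).flatMap { case (s1, a) => f(a).run(s1) })
}

object EitherState {
  def get[E, S]: EitherState[E, S, S] =
    EitherState[E, S, S]((s: S) => Right((s, s)))
  def modify[E, S](f: S => S): EitherState[E, S, Unit] =
    EitherState[E, S, Unit]((s: S) => Right((f(s), ())))
  // Lift a plain Either, leaving the state untouched
  def lift[E, S, A](ea: Either[E, A]): EitherState[E, S, A] =
    EitherState[E, S, A]((s: S) => ea.map(a => (s, a)))
}
```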

Let's again look at the tree example to get a better understanding of how to use the StateT type to stack another monad inside.

StateT takes an additional type constructor as a type parameter (the monad to stack inside). However, we cannot simply pass in Either. StateT expects a type constructor that takes a single type parameter, while Either takes two. Therefore, we have to define a type alias that fixes the error type:
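
The kind mismatch shows up with any abstraction that expects a one-parameter type constructor. In the following sketch, Mapper is a hypothetical stand-in for such an abstraction and plays the role of StateT's F[_] parameter. The compiler would reject Mapper[Either] but accepts Mapper[ErrorOr].

```scala
object KindDemo {
  type Error = String
  type ErrorOr[A] = Either[Error, A]

  // Hypothetical abstraction over a one-parameter type constructor,
  // standing in for StateT's F[_] parameter
  trait Mapper[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  // Mapper[Either] would not compile, because Either takes two type parameters.
  // ErrorOr fixes the error type and has exactly the shape Mapper expects.
  val errorOrMapper: Mapper[ErrorOr] = new Mapper[ErrorOr] {
    def map[A, B](fa: ErrorOr[A])(f: A => B): ErrorOr[B] = fa.map(f)
  }
}
```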

type Error = String
type ErrorOr[A] = Either[Error, A]

Now we can define an alias for the combined state and either monad:

type EitherState[A] = StateT[ErrorOr, Int, A]

Next we make a few changes to the fromTree function:

def fromTree[A](tree: Tree[A]): EitherState[LabeledTree[A]] = {
  tree match {
    case Leaf(a) =>
      for {
        state <- StateT.get[ErrorOr, Int]
        _ <- StateT.modify[ErrorOr, Int](s => s + 1)
      } yield Leaf(state, a): LabeledTree[A]
    case Branch(left, right) =>
      for {
        l <- fromTree(left)
        r <- fromTree(right)
      } yield Branch(l, r)
  }
}

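Everything stayed pretty much the same. get and modify are now called on StateT instead of State, so they take ErrorOr as an additional type argument. The leaf case also gains the type ascription : LabeledTree[A], which widens the yielded Leaf to the declared result type.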

Calling runA with an initial state of 0 will output:

Right(Branch(Leaf((0, 'a')), Branch(Branch(Leaf((1, 'b')), Leaf((2, 'c'))), Leaf((3, 'd')))))
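
The labeling above never fails. To see the error channel in action, the following self-contained sketch repeats the labeling with a hand-rolled StateT-over-Either (not Cats). It adds a hypothetical rule: labeling fails once more than maxLabels leaves have been seen. The failure short-circuits the rest of the traversal.

```scala
// Hand-rolled State stacked over Either (illustration only, not Cats' StateT)
final case class EitherState[E, S, A](run: S => Either[E, (S, A)]) {
  def map[B](f: A => B): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).map { case (s1, a) => (s1, f(a)) })
  def flatMap[B](f: A => EitherState[E, S, B]): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).flatMap { case (s1, a) => f(a).run(s1) })
  // Like runA in Cats: keep the value (or the error), drop the final state
  def runA(initial: S): Either[E, A] = run(initial).map(_._2)
}

object EitherState {
  def get[E, S]: EitherState[E, S, S] = EitherState[E, S, S]((s: S) => Right((s, s)))
  def modify[E, S](f: S => S): EitherState[E, S, Unit] =
    EitherState[E, S, Unit]((s: S) => Right((f(s), ())))
  def fail[E, S, A](e: E): EitherState[E, S, A] = EitherState[E, S, A]((_: S) => Left(e))
}

sealed trait Tree[+A]
final case class Branch[+A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](a: A) extends Tree[A]

object BoundedLabeling {
  type LabeledTree[A] = Tree[(Int, A)]
  type Labeled[A] = EitherState[String, Int, A]

  // Current label plus increment, or an error once maxLabels is reached
  // (maxLabels is a hypothetical rule that exists only to exercise the Left path)
  private def nextLabel(maxLabels: Int): Labeled[Int] =
    EitherState.get[String, Int].flatMap { n =>
      if (n < maxLabels) EitherState.modify[String, Int](_ + 1).map(_ => n)
      else EitherState.fail[String, Int, Int](s"more than $maxLabels leaves")
    }

  def fromTree[A](tree: Tree[A], maxLabels: Int): Labeled[LabeledTree[A]] =
    tree match {
      case Leaf(a) =>
        nextLabel(maxLabels).map(n => (Leaf((n, a)): LabeledTree[A]))
      case Branch(left, right) =>
        for {
          l <- fromTree(left, maxLabels)
          r <- fromTree(right, maxLabels)
        } yield (Branch(l, r): LabeledTree[A])
    }
}
```

With maxLabels = 10 the result is the same Right as above. With maxLabels = 2 the third leaf fails, and runA yields Left("more than 2 leaves").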

Handle invalid JSON inputs

Let's apply this to the schema extractor. First, we define the type aliases:

type Error = String
type ErrorOr[A] = Either[Error, A]
type EitherState[A] = StateT[ErrorOr, SchemaId, A]

In fromJsRoot we use lift to lift a Left value of type ErrorOr[Schema] into StateT:

def fromJsRoot(value: JsValue): EitherState[Schema] = {
  value match {
    case JsObject(fields) =>
      fromJsField("root", fields.toList)
    case _ =>
      lift[ErrorOr, SchemaId, Schema](Left("JSON root value must be an object"))
  }
}
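
lift wraps an existing ErrorOr value into StateT without touching the state. A lifted Right behaves like pure. A lifted Left aborts the whole computation and discards any state changes made before it. The hand-rolled sketch below demonstrates this. EitherState is an illustration, not Cats' API. Recent Cats releases expose the same lifting operation as StateT.liftF.

```scala
// Hand-rolled State stacked over Either (illustration only, not Cats' StateT)
final case class EitherState[E, S, A](run: S => Either[E, (S, A)]) {
  def flatMap[B](f: A => EitherState[E, S, B]): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).flatMap { case (s1, a) => f(a).run(s1) })
}

object EitherState {
  def pure[E, S, A](a: A): EitherState[E, S, A] =
    EitherState[E, S, A]((s: S) => Right((s, a)))
  def modify[E, S](f: S => S): EitherState[E, S, Unit] =
    EitherState[E, S, Unit]((s: S) => Right((f(s), ())))
  // lift: reuse an existing Either, passing the state through unchanged
  def lift[E, S, A](ea: Either[E, A]): EitherState[E, S, A] =
    EitherState[E, S, A]((s: S) => ea.map(a => (s, a)))
}
```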

In the Id generation part only get and modify are changed:

private def fromJsField(name: ClassName, fields: List[(String, JsValue)]): EitherState[Schema] = {
  val schemaFieldsState =
    fields.map { case (fieldName, value) =>
      fromJsValue(fieldName, value)
    }

  for {
    schemaFields <- schemaFieldsState.sequence
    state <- get[ErrorOr, SchemaId]
    _ <- modify[ErrorOr, SchemaId](Schema.nextId)
  } yield
    SchemaObject(
      state,
      name,
      schemaFields)
}

In fromJsValue we create an instance of StateT representing an error with lift and instances representing a successful result with pure:

private def fromJsValue(name: String, value: JsValue): EitherState[(String, Schema)] = {
  val schema = value match {
    case JsNull =>
      lift[ErrorOr, SchemaId, Schema](Left("cannot analyze type of a JSON null value"))
    case JsString(_) =>
      pure[ErrorOr, SchemaId, Schema](SchemaString)
    // ...

We have now covered everything we need to adapt fromJsArray accordingly:

private def fromJsArray(name: String, values: List[JsValue]): EitherState[Schema] = {
  for {
    schemas <- values.map(v => fromJsValue(name, v).map(_._2)).sequence
    schema <- schemas.headOption match {
      case Some(first) if schemas forall Schema.haveSameStructure(first) =>
        pure[ErrorOr, SchemaId, Schema](SchemaArray(first))
      case None =>
        lift[ErrorOr, SchemaId, Schema](Left("cannot analyze empty JSON array"))
      case _ =>
        lift[ErrorOr, SchemaId, Schema](Left("array type is not consistent"))
    }
  } yield schema
}
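
Putting the pieces together, here is a compact, dependency-free sketch of the whole extractor. It makes three substitutions:

  • a hypothetical mini JSON type (Js) instead of play-json
  • a hypothetical schema type (Sch) instead of the post's Schema
  • a hand-rolled EitherState instead of Cats' StateT

sameStructure is a hypothetical stand-in for Schema.haveSameStructure that ignores Ids and names.

```scala
// Hand-rolled State stacked over Either (illustration only, not Cats' StateT)
final case class EitherState[E, S, A](run: S => Either[E, (S, A)]) {
  def map[B](f: A => B): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).map { case (s1, a) => (s1, f(a)) })
  def flatMap[B](f: A => EitherState[E, S, B]): EitherState[E, S, B] =
    EitherState[E, S, B]((s: S) => run(s).flatMap { case (s1, a) => f(a).run(s1) })
  def runA(initial: S): Either[E, A] = run(initial).map(_._2)
}

object EitherState {
  def pure[E, S, A](a: A): EitherState[E, S, A] = EitherState[E, S, A]((s: S) => Right((s, a)))
  def fail[E, S, A](e: E): EitherState[E, S, A] = EitherState[E, S, A]((_: S) => Left(e))
  def get[E, S]: EitherState[E, S, S] = EitherState[E, S, S]((s: S) => Right((s, s)))
  def modify[E, S](f: S => S): EitherState[E, S, Unit] =
    EitherState[E, S, Unit]((s: S) => Right((f(s), ())))
  // Runs the steps left to right and stops at the first Left
  def sequence[E, S, A](xs: List[EitherState[E, S, A]]): EitherState[E, S, List[A]] =
    xs.foldRight(pure[E, S, List[A]](Nil)) { (x, acc) => x.flatMap(a => acc.map(as => a :: as)) }
}

// Hypothetical, simplified JSON and schema types (stand-ins for play-json and Schema)
sealed trait Js
case object JNull extends Js
final case class JStr(value: String) extends Js
final case class JNum(value: Double) extends Js
final case class JArr(values: List[Js]) extends Js
final case class JObj(fields: List[(String, Js)]) extends Js

sealed trait Sch
case object SStr extends Sch
case object SNum extends Sch
final case class SArr(element: Sch) extends Sch
final case class SObj(id: Int, name: String, fields: List[(String, Sch)]) extends Sch

object MiniExtractor {
  type Result[A] = EitherState[String, Int, A]

  private def ok(s: Sch): Result[Sch] = EitherState.pure[String, Int, Sch](s)
  private def err(msg: String): Result[Sch] = EitherState.fail[String, Int, Sch](msg)

  // Hypothetical stand-in for Schema.haveSameStructure: ignores Ids and names
  def sameStructure(a: Sch, b: Sch): Boolean = (a, b) match {
    case (SArr(x), SArr(y)) => sameStructure(x, y)
    case (SObj(_, _, fa), SObj(_, _, fb)) =>
      fa.map(_._1) == fb.map(_._1) &&
        fa.zip(fb).forall { case ((_, x), (_, y)) => sameStructure(x, y) }
    case _ => a == b
  }

  def fromRoot(js: Js): Result[Sch] = js match {
    case JObj(fields) => fromObject("root", fields)
    case _            => err("JSON root value must be an object")
  }

  private def fromObject(name: String, fields: List[(String, Js)]): Result[Sch] =
    for {
      schemaFields <- EitherState.sequence(fields.map { case (n, v) => fromValue(n, v).map(s => (n, s)) })
      id           <- EitherState.get[String, Int] // children are numbered first
      _            <- EitherState.modify[String, Int](_ + 1)
    } yield (SObj(id, name, schemaFields): Sch)

  private def fromValue(name: String, js: Js): Result[Sch] = js match {
    case JNull        => err("cannot analyze type of a JSON null value")
    case JStr(_)      => ok(SStr)
    case JNum(_)      => ok(SNum)
    case JObj(fields) => fromObject(name, fields)
    case JArr(values) => fromArray(name, values)
  }

  // Mirrors fromJsArray: empty and inconsistent arrays become errors
  private def fromArray(name: String, values: List[Js]): Result[Sch] =
    EitherState.sequence(values.map(v => fromValue(name, v))).flatMap { schemas =>
      schemas.headOption match {
        case Some(first) if schemas.forall(sameStructure(first, _)) => ok(SArr(first))
        case None => err("cannot analyze empty JSON array")
        case _    => err("array type is not consistent")
      }
    }
}
```

As in the post, nested objects receive their Ids before their parents. This is why the root object ends up with the highest Id.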

We can easily test the error cases like this:

test("test null value") {
  val json = """{ "x": null }"""

  val actual = SchemaExtractor.fromJsRoot(Json.parse(json)).runA(0)

  assert(actual == Left("cannot analyze type of a JSON null value"))
}

Now the schema extraction function is completely pure, total and deterministic.

The complete source code can be found on GitHub.