A small benchmark util for Scala

Specifying the benchmark

Note: there’s now also a follow-up post which extends this benchmark.

This post was inspired by an answer to the "Hidden features of Scala" question over at StackOverflow.com (and also by another blog post that I can no longer find). I expanded a little on the code posted there, which resulted in the following neat lump of code.

```scala
// The outcome of a single timed run: Left on failure, Right on success
case class TaskDone[Id](id: Id, run: Int, duration: Long)
case class TaskFail[Id](id: Id, run: Int, duration: Long, error: Throwable)
case class BenchDone[Id](warmups: Seq[Task[Id]], runs: Seq[Task[Id]])

type Report[Result] = Result => Unit
type Task[Id] = Either[TaskFail[Id], TaskDone[Id]]
type Bench[Id] = BenchDone[Id]
type Batch[Id] = Seq[Bench[Id]]

// Times one run of fn and hands the outcome to the reporter
def task[Id](id: Id, fn: () => Unit)(run: Int)(report: Report[Task[Id]]): Task[Id] = {
  val begin = System.nanoTime
  val result: Task[Id] = try {
    fn()
    Right(TaskDone(id = id, run = run, duration = System.nanoTime - begin))
  } catch {
    // Throwable on purpose, so StackOverflowError and OutOfMemoryError are reported
    case exp: Throwable =>
      Left(TaskFail(id = id, run = run, duration = System.nanoTime - begin, error = exp))
  }
  report(result)
  result
}

// Runs the warmups, then the measured runs, and reports the accumulated results
def benchmark[Id](task: Int => Report[Task[Id]] => Task[Id])(runs: Int, warmups: Int)
                 (warmupReport: Report[Task[Id]], runReport: Report[Task[Id]])
                 (benchReport: Report[Bench[Id]]): Bench[Id] = {
  assert(runs > 0, "Number of runs must be greater than zero.")
  assert(warmups >= 0, "Number of warmups must be zero or greater.")
  val (warmupsResults, runsResults) = (
    (1 to warmups) map (run => task(run)(warmupReport)),
    (1 to runs) map (run => task(run)(runReport))
  )
  val result = BenchDone(warmups = warmupsResults, runs = runsResults)
  benchReport(result)
  result
}

// Runs every benchmark with the same bench reporter, then reports the whole batch
def batch[Id](benchs: Seq[Report[Bench[Id]] => Bench[Id]])
             (benchReport: Report[Bench[Id]])(batchReport: Report[Batch[Id]]): Unit = {
  val result = benchs map (bench => bench(benchReport))
  batchReport(result)
}
```

What this code does is create a task() from an id and a function of type () => Unit to apply. The task is then handed to a benchmark(), which is further specified with how many runs and warmups to perform and with a pair of task reporters that handle the individual results. Finally, a batch() is constructed from a sequence of benchmarks and given a benchmark reporter. Supplying the last argument, a batch reporter, makes the batch run its benchmarks, which in turn run the individual tasks. The benchmark results are then accumulated and fed into the batch reporter.
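The staging works because of the curried parameter lists: each argument list you supply returns a function waiting for the next one, and nothing runs until the last one arrives. The sketch below shows this in isolation with a simplified copy of task() whose result is just an (id, run, nanoseconds) tuple instead of the Either used above; StagingSketch and demo() are hypothetical names used only for illustration.

```scala
import scala.collection.mutable.ListBuffer

object StagingSketch {
  // Simplified stand-ins (assumption): a result is just (id, run, elapsed nanoseconds)
  type Report[R] = R => Unit
  type Result = (String, Int, Long)

  // Same three parameter lists as task() in the post: (id, fn)(run)(report)
  def task(id: String, fn: () => Unit)(run: Int)(report: Report[Result]): Result = {
    val begin = System.nanoTime
    fn()
    val result = (id, run, System.nanoTime - begin)
    report(result)
    result
  }

  // Returns (calls before the reporter was given, calls after, reported results)
  def demo(): (Int, Int, List[Result]) = {
    var calls = 0
    val reported = ListBuffer.empty[Result]
    // Stage 1: id and fn are fixed, nothing has run yet
    val staged: Int => Report[Result] => Result =
      task("sum", () => { calls += 1; (1 to 1000).sum }) _
    // Stage 2: the run number is fixed, still nothing has run
    val secondRun: Report[Result] => Result = staged(2)
    val callsBefore = calls
    // Stage 3: supplying the reporter finally times fn and reports the result
    secondRun(result => reported += result)
    (callsBefore, calls, reported.toList)
  }

  def main(args: Array[String]): Unit = println(demo())
}
```

The counter confirms that fn is only invoked once the reporter is supplied, which is what lets benchmark() and batch() pass partially applied tasks around freely.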

The following diagram gives a visual sketch of the code. Note that half-open endpoints signify arguments that are already supplied from outside all the enclosing functions.

Example usage

Before specifying the tasks, let's first define the reporters that will handle the results. This is where the bulk of the code resides.

For tasks we can define two kinds of reporters: one for warmup runs and one for ordinary runs. Each reporter handles a result in the form of an Either[TaskFail[Id],TaskDone[Id]], which it can pattern match against. A Right[TaskDone[Id]] result indicates a successful run and carries the id, run number and duration. Otherwise, a Left[TaskFail[Id]] result indicates a failed run and carries the id, run number, duration and the exception that was thrown. A benchmark reporter is also defined, which handles the accumulated results of the warmup runs and ordinary runs. The resulting data from all benchmarks is finally fed into a batch reporter.

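In the following code, warmupReport() and runReport() specify the task reporters, benchReport() handles the accumulated runs, and batchReport() presents the final statistics. They rely on a few helpers: timeToText() turns nanoseconds into text, taskId() extracts a benchmark's id, isDone() checks that no warmup or run failed, and avgTime() averages the run durations.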

```scala
type Id = String

// Turns nanoseconds into text such as "2s774ms"
def timeToText(value: Long): String = {
  val ms = value / 1000000L
  val msPart = ms % 1000L
  val sec = (ms - msPart) / 1000L
  val secPart = sec % 60L
  val min = (sec - secPart) / 60L
  if (min > 0) s"${min}m${secPart}s${msPart}ms"
  else if (secPart > 0) s"${secPart}s${msPart}ms"
  else s"${msPart}ms"
}

// Extracts the task id; every run of a benchmark carries the same id,
// and benchmark() guarantees at least one run
def taskId[Id](bench: Bench[Id]): Id =
  bench.runs.head.fold(fail => fail.id, done => done.id)

// Detects a benchmark without errors in either warmups or runs
def isDone[Id](bench: Bench[Id]): Boolean = {
  val warmupsFail = bench.warmups collect { case Left(task) => task }
  val runsFail = bench.runs collect { case Left(task) => task }
  warmupsFail.isEmpty && runsFail.isEmpty
}

// Average duration of the measured runs in a benchmark
def avgTime(bench: Bench[Id]): Long = {
  val totals = bench.runs collect {
    case Right(task) => task.duration
    case Left(task) => task.duration
  }
  if (totals.nonEmpty) totals.sum / totals.length else 0L
}

// Prints one line per warmup or run
def taskReport(title: String)(result: Task[Id]): Unit = result match {
  case Right(task) =>
    println(s"${task.id} $title ${task.run} completed in ${timeToText(task.duration)}")
  case Left(task) =>
    println(s"${task.id} $title ${task.run} failed after ${timeToText(task.duration)}" +
      s" from ${task.error.getClass.getName}")
}

val warmupReport = taskReport("warmup") _
val runReport = taskReport("run") _

// Prints the total time of the measured runs, or that the benchmark failed
def benchReport[Id](bench: Bench[Id]): Unit = {
  val id = taskId(bench)
  if (isDone(bench)) {
    val totalTime = bench.runs.collect { case Right(task) => task.duration }.sum
    println(s"$id finished in ${timeToText(totalTime)}\n")
  } else println(s"$id failed\n")
}

// Prints a summary sorted by average time, followed by the failed benchmarks
def batchReport(batch: Batch[Id]): Unit = {
  val (doneBenchs, failedBenchs) = batch partition isDone
  println(
    "Batch of benchmarks finished with %s completed and %s failed.\n"
      .format(doneBenchs.length, failedBenchs.length)
  )
  println("Average times for benchmarks:")
  println(
    doneBenchs
      .map(bench => (taskId(bench), avgTime(bench)))
      .sortWith((a, b) => a._2 < b._2) // sort on durations
      .map(item => "%10s %s".format(timeToText(item._2), item._1))
      .mkString("\n")
  )
  println(
    failedBenchs
      .map(bench => "%10s %s".format("na", taskId(bench)))
      .mkString("\n")
  )
}
```
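The arithmetic in timeToText() splits a nanosecond count into whole minutes, the seconds left over within the minute, and the milliseconds left over within the second. As a cross-check, here is a sketch of the same split using java.time.Duration instead of manual division. It assumes Java 9 or later for toSecondsPart() and toMillisPart(), and durationToText() is a hypothetical helper, not part of the code above.

```scala
import java.time.Duration

object DurationText {
  // Hypothetical helper mirroring timeToText(): minutes, seconds within the minute,
  // milliseconds within the second (toSecondsPart/toMillisPart need Java 9+)
  def durationToText(nanos: Long): String = {
    val d = Duration.ofNanos(nanos)
    val min = d.toMinutes()     // whole minutes, not capped at 60
    val sec = d.toSecondsPart() // seconds within the minute
    val ms = d.toMillisPart()   // milliseconds within the second
    if (min > 0) s"${min}m${sec}s${ms}ms"
    else if (sec > 0) s"${sec}s${ms}ms"
    else s"${ms}ms"
  }

  // Prints 185ms, 2s774ms and 1m5s3ms
  def main(args: Array[String]): Unit =
    List(185000000L, 2774000000L, 65003000000L).foreach(n => println(durationToText(n)))
}
```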

With the reporters in place, we can define some tasks to benchmark. We want to check how well the Scala collections library handles folding over ranges. A range in Scala is created simply by writing 1 to 100, which represents the sequence of numbers from 1 to 100. This doesn't actually create and allocate a whole list; instead, the individual elements are computed on demand. What we want to check is how this data structure behaves when we fold it to sum up its elements, both from left to right and from right to left. If ranges have any problems, they will most likely show up with large enough ranges. For comparison, we will also run the same operations on ordinary lists, vectors, arrays and streams, as well as on iterators over lists and streams.
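The claim about ranges is easy to check. The sketch below, which assumes the Scala 2.13 collections, builds a range with a billion elements without allocating them, since a Range only stores its start, end and step. It also shows that (start to end) includes its upper bound while List.range(start, end) excludes it, which means the range tasks below sum one element more than the other tasks. RangeSketch and its members are hypothetical names.

```scala
object RangeSketch {
  // A Range stores only start, end and step, so no billion-element collection is allocated
  val huge: Range = 1 to 1000000000

  // Element counts for an inclusive range versus List.range, which excludes its end
  def sizes(start: Int, end: Int): (Int, Int) =
    ((start to end).length, List.range(start, end).length)

  def sumLeft(r: Range): Int = r.foldLeft(0)(_ + _)
  def sumRight(r: Range): Int = r.foldRight(0)(_ + _)

  def main(args: Array[String]): Unit = {
    println(huge.length)       // 1000000000, computed from start, end and step
    println(huge(999))         // 1000, computed on demand
    println(sizes(1, 5))       // (5,4)
    println(sumLeft(1 to 100)) // 5050
    println(sumRight(1 to 100)) // 5050
  }
}
```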

```scala
val (start, end) = (1, 5000000)
// Note: (start to end) includes end, while the X.range(start, end) variants exclude it.
// The sums do not fit in an Int and wrap around, which is harmless since only timings matter.
val sum = (a: Int, b: Int) => a + b

val rangeTasks = List(
  task(id = "range+sum", fn = () => (start to end).sum) _,
  task(id = "range+foldLeft", fn = () => (start to end).foldLeft(0)(sum)) _,
  task(id = "range+foldRight", fn = () => (start to end).foldRight(0)(sum)) _
)
val listTasks = List(
  task(id = "list+sum", fn = () => List.range(start, end).sum) _,
  task(id = "list+foldLeft", fn = () => List.range(start, end).foldLeft(0)(sum)) _,
  task(id = "list+foldRight", fn = () => List.range(start, end).foldRight(0)(sum)) _
)
val vectorTasks = List(
  task(id = "vector+sum", fn = () => Vector.range(start, end).sum) _,
  task(id = "vector+foldLeft", fn = () => Vector.range(start, end).foldLeft(0)(sum)) _,
  task(id = "vector+foldRight", fn = () => Vector.range(start, end).foldRight(0)(sum)) _
)
val arrayTasks = List(
  task(id = "array+sum", fn = () => Array.range(start, end).sum) _,
  task(id = "array+foldLeft", fn = () => Array.range(start, end).foldLeft(0)(sum)) _,
  task(id = "array+foldRight", fn = () => Array.range(start, end).foldRight(0)(sum)) _
)
// Stream has since been deprecated (Scala 2.13) in favour of LazyList
val streamTasks = List(
  task(id = "stream+sum", fn = () => Stream.range(start, end).sum) _,
  task(id = "stream+foldLeft", fn = () => Stream.range(start, end).foldLeft(0)(sum)) _,
  task(id = "stream+foldRight", fn = () => Stream.range(start, end).foldRight(0)(sum)) _
)
// Iterator-based variants of the list and stream operations that run into trouble
val itrTasks = List(
  task(id = "list+itr+foldRight", fn = () => List.range(start, end).iterator.foldRight(0)(sum)) _,
  task(id = "stream+itr+sum", fn = () => Stream.range(start, end).iterator.sum) _,
  task(id = "stream+itr+foldRight", fn = () => Stream.range(start, end).iterator.foldRight(0)(sum)) _
)

// One warmup and two measured runs per task
val benchs = (
  rangeTasks ++ listTasks ++ vectorTasks ++ arrayTasks ++ streamTasks ++ itrTasks
).map(task => benchmark(task)(runs = 2, warmups = 1)(warmupReport, runReport) _)

batch(benchs)(benchReport)(batchReport)
```

Running this code gave the following results on my computer. They show that foldRight() runs into trouble with most of the collections: it is considerably slower than foldLeft() on ranges and vectors, and on lists and streams it overflows the call stack. Summing a stream also ran out of memory, whereas folding it from the left did not. If the range is extended slightly further, foldRight() on a range will run out of memory as well. Don't get too worried, though: if the number of elements is fairly small, these constraints have little influence. Also, according to the Scala mailing list, you can avoid the problem by turning your sequence into an iterator with seq.iterator.foldRight(). As far as I can tell, this reverses the elements and then folds from the left, which sidesteps the deep recursion that the JVM cannot optimize away, since it lacks tail-call optimization.
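To see why right folds are fragile, compare a naive recursive right fold with a stack-safe one. The naive version must recurse all the way to the end of the list before it can combine anything, so it uses one stack frame per element. The safe version reverses the list once and then combines from the left in a tail-recursive loop, which I believe is the same idea behind the iterator workaround. naiveFoldRight(), safeFoldRight() and overflows() are hypothetical helpers written for this sketch, not library methods.

```scala
import scala.annotation.tailrec

object FoldRightSketch {
  // Naive right fold: recurses to the end of the list first, one stack frame per element
  def naiveFoldRight[A, B](xs: List[A], z: B)(op: (A, B) => B): B = xs match {
    case Nil          => z
    case head :: tail => op(head, naiveFoldRight(tail, z)(op))
  }

  // Stack-safe right fold: reverse once, then combine from the left in a tail-recursive loop
  def safeFoldRight[A, B](xs: List[A], z: B)(op: (A, B) => B): B = {
    @tailrec
    def loop(rest: List[A], acc: B): B = rest match {
      case Nil          => acc
      case head :: tail => loop(tail, op(head, acc))
    }
    loop(xs.reverse, z)
  }

  // True if evaluating body throws a StackOverflowError
  def overflows(body: => Any): Boolean =
    try { body; false } catch { case _: StackOverflowError => true }

  def main(args: Array[String]): Unit = {
    val big = List.range(0, 1000000)
    // Typically true: a million nested calls exceed the default thread stack size
    println(overflows(naiveFoldRight(big, 0L)(_ + _)))
    println(safeFoldRight(big, 0L)(_ + _)) // 499999500000
  }
}
```

The price of the safe version is an extra reversed copy of the elements.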

```
range+sum warmup 1 completed in 185ms
range+sum run 1 completed in 139ms
range+sum run 2 completed in 140ms
range+sum finished in 279ms

range+foldLeft warmup 1 completed in 138ms
range+foldLeft run 1 completed in 140ms
range+foldLeft run 2 completed in 140ms
range+foldLeft finished in 280ms

range+foldRight warmup 1 completed in 2s774ms
range+foldRight run 1 completed in 1s280ms
range+foldRight run 2 completed in 1s249ms
range+foldRight finished in 2s529ms

list+sum warmup 1 completed in 1s682ms
list+sum run 1 completed in 1s179ms
list+sum run 2 completed in 1s581ms
list+sum finished in 2s761ms

list+foldLeft warmup 1 completed in 1s428ms
list+foldLeft run 1 completed in 1s615ms
list+foldLeft run 2 completed in 1s478ms
list+foldLeft finished in 3s94ms

list+foldRight warmup 1 failed after 682ms from java.lang.StackOverflowError
list+foldRight run 1 failed after 869ms from java.lang.StackOverflowError
list+foldRight run 2 failed after 741ms from java.lang.StackOverflowError
list+foldRight failed

vector+sum warmup 1 completed in 792ms
vector+sum run 1 completed in 945ms
vector+sum run 2 completed in 783ms
vector+sum finished in 1s728ms

vector+foldLeft warmup 1 completed in 831ms
vector+foldLeft run 1 completed in 908ms
vector+foldLeft run 2 completed in 745ms
vector+foldLeft finished in 1s654ms

vector+foldRight warmup 1 completed in 2s67ms
vector+foldRight run 1 completed in 1s888ms
vector+foldRight run 2 completed in 2s51ms
vector+foldRight finished in 3s939ms

array+sum warmup 1 completed in 587ms
array+sum run 1 completed in 303ms
array+sum run 2 completed in 284ms
array+sum finished in 588ms

array+foldLeft warmup 1 completed in 268ms
array+foldLeft run 1 completed in 269ms
array+foldLeft run 2 completed in 328ms
array+foldLeft finished in 597ms

array+foldRight warmup 1 completed in 254ms
array+foldRight run 1 completed in 267ms
array+foldRight run 2 completed in 245ms
array+foldRight finished in 512ms

stream+sum warmup 1 failed after 10s30ms from java.lang.OutOfMemoryError
stream+sum run 1 failed after 9s46ms from java.lang.OutOfMemoryError
stream+sum run 2 failed after 9s271ms from java.lang.OutOfMemoryError
stream+sum failed

stream+foldLeft warmup 1 completed in 575ms
stream+foldLeft run 1 completed in 606ms
stream+foldLeft run 2 completed in 574ms
stream+foldLeft finished in 1s181ms

stream+foldRight warmup 1 failed after 1ms from java.lang.StackOverflowError
stream+foldRight run 1 failed after 2ms from java.lang.StackOverflowError
stream+foldRight run 2 failed after 2ms from java.lang.StackOverflowError
stream+foldRight failed

list+itr+foldRight warmup 1 completed in 2s925ms
list+itr+foldRight run 1 completed in 2s812ms
list+itr+foldRight run 2 completed in 3s19ms
list+itr+foldRight finished in 5s831ms

stream+itr+sum warmup 1 completed in 1s547ms
stream+itr+sum run 1 completed in 1s41ms
stream+itr+sum run 2 completed in 871ms
stream+itr+sum finished in 1s912ms

stream+itr+foldRight warmup 1 completed in 3s715ms
stream+itr+foldRight run 1 completed in 3s438ms
stream+itr+foldRight run 2 completed in 3s623ms
stream+itr+foldRight finished in 7s62ms

Batch of benchmarks finished with 15 completed and 3 failed.

Average times for benchmarks:
     139ms range+sum
     140ms range+foldLeft
     256ms array+foldRight
     294ms array+sum
     298ms array+foldLeft
     590ms stream+foldLeft
     827ms vector+foldLeft
     864ms vector+sum
     956ms stream+itr+sum
   1s264ms range+foldRight
   1s380ms list+sum
   1s547ms list+foldLeft
   1s969ms vector+foldRight
   2s915ms list+itr+foldRight
   3s531ms stream+itr+foldRight
        na list+foldRight
        na stream+sum
        na stream+foldRight
```
