Hashrocket.com / blog


Custom Aggregates in PostgreSQL

Written by Jack Christensen

Given a bank account of debit and credit transactions, what is the greatest balance the account ever had? Given a hotel with arrivals and departures, what is the greatest number of guests the hotel ever had? Both of these are examples of finding the greatest running total. Finding the greatest running total is a great exercise with which to explore some of the lesser known features of PostgreSQL such as window functions, custom aggregates, and C functions.

The Setup

All code can be found on GitHub.

For this example, we will use a simple data schema that only contains an amount column and id column to provide ordering.

create table entries(
  id serial primary key,
  amount float8 not null
);

We will use random() and generate_series() to insert 1,000,000 rows of test data. By calling setseed() before calling random() we can ensure that this code always produces the same data.

select setseed(0);

insert into entries(amount)
select (2000 * random()) - 1000
from generate_series(1, 1000000);

Running Total

To find the greatest running total, we first have to find the running total for every row. This can be easily done with a window function.

select
  id,
  amount,
  sum(amount) over (order by id asc) as running_total
from entries
order by id asc;
   id    |        amount         |   running_total
---------+-----------------------+--------------------
       1 |     -462.016298435628 |  -462.016298435628
       2 |      162.440904416144 |  -299.575394019485
       3 |     -820.292402990162 |  -1119.86779700965
       4 |     -866.230697371066 |  -1986.09849438071
       5 |      -495.30001822859 |   -2481.3985126093
       6 |      772.393747232854 |  -1709.00476537645
       7 |     -323.866365477443 |  -2032.87113085389
       8 |     -856.917716562748 |  -2889.78884741664
       9 |      285.323366522789 |  -2604.46548089385
      10 |     -867.916810326278 |  -3472.38229122013
-- snip --

The expression sum(amount) over (order by id asc) can be read as "sum amount over all rows, ordered by id ascending, from the first row through the current row." See the window function tutorial in the PostgreSQL documentation if you need a primer on window functions.
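
As a mental model only (this is not how PostgreSQL implements window functions), the running total can be sketched in plain C. The helper name running_totals is hypothetical, and the sketch assumes the amounts are already sorted by id.

```c
#include <stddef.h>

/* Hypothetical helper: writes the sum of amounts[0..i] into out[i],
 * mirroring sum(amount) over (order by id asc) for rows that are
 * already in id order. */
void running_totals(const double *amounts, size_t n, double *out)
{
    double total = 0.0;
    for (size_t i = 0; i < n; i++) {
        total += amounts[i];   /* first row through the current row */
        out[i] = total;
    }
}
```

For the amounts 100, -50, 25 this fills out with 100, 50, 75.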

Greatest Running Total

Now that we have the running total for every row it should be simple to use the max aggregate function to find the greatest running total.

select max(sum(amount) over (order by id asc))
from entries;

Unfortunately, we get an error:

ERROR:  aggregate function calls cannot contain window function calls
LINE 1: select max(sum(amount) over (order by id asc))

Instead, we have to use a subquery.

select max(running_total)
from (
  select sum(amount) over (order by id asc) as running_total
  from entries
) t;

Here is the result:

       max
------------------
 396271.274807863
(1 row)

Time: 643.848 ms

Not too bad, but there are two potential areas for improvement: query simplicity and speed. What we really want to do is this:

select greatest_running_total(amount order by id asc)
from entries;

Note the order by id asc in the aggregate. Because a greatest_running_total function would require its inputs to be ordered to be correct, it is vital that we include this clause.
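
To see why the ordering matters, here is a minimal C sketch of the calculation itself. The function below is a hypothetical stand-in for the aggregate, not its implementation, and it assumes at least one input value.

```c
#include <stddef.h>

/* Hypothetical stand-in for the aggregate: returns the greatest
 * running total of amounts[0..n-1], taken in array order.
 * Assumes n >= 1. */
double greatest_running_total(const double *amounts, size_t n)
{
    double total = amounts[0];
    double greatest = amounts[0];
    for (size_t i = 1; i < n; i++) {
        total += amounts[i];
        if (total > greatest)
            greatest = total;
    }
    return greatest;
}
```

The same two amounts give different answers depending on their order: 100 followed by -50 peaks at 100, while -50 followed by 100 peaks at 50.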

Custom Aggregates

The greatest_running_total function doesn't exist, but PostgreSQL gives us the functionality to create our own aggregate functions. In this case, our greatest_running_total aggregate should accept float8 values and return a float8.

To create an aggregate function we first need a state transition function. This function will be called for each input row with the aggregate's internal state and the current row value. The internal state needs to contain the current running total as well as the greatest running total, so we need a structure of two float8 values. Fortunately, PostgreSQL has the point type, which is exactly what we need (a float8 array would also work, but point is simpler to use and potentially faster).

The following state transition function is implemented in PL/pgSQL.

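create function grt_sfunc(agg_state point, el float8)
returns point
immutable
language plpgsql
as $$
declare
  greatest_sum float8;
  current_sum float8;
begin
  -- The aggregate is created without an initial condition, so the state
  -- is null for the first row. Seed both sums with the first value.
  if agg_state is null then
    return point(el, el);
  end if;

  current_sum := agg_state[0] + el;
  if agg_state[1] < current_sum then
    greatest_sum := current_sum;
  else
    greatest_sum := agg_state[1];
  end if;

  return point(current_sum, greatest_sum);
end;
$$;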

The point agg_state is used as a 2-element, zero-based array. agg_state[0] is the current sum; agg_state[1] is the greatest sum the aggregate has seen. We simply add agg_state[0] and the current row value el to get the new current sum. The new greatest sum is the greater of the old greatest sum (agg_state[1]) and the new current sum. Finally, we return a new point value with the new current and greatest sums.

Because our aggregate's internal state is of type point and the output of our aggregate is float8, we need an aggregate final function that takes the final value of the aggregate's internal state and converts it to a float8.

create function grt_finalfunc(agg_state point)
returns float8
immutable
strict
language plpgsql
as $$
begin
  return agg_state[1];
end;
$$;

Lastly, we have to create the aggregate by providing the state transition function, internal aggregate state type, and the final function.

create aggregate greatest_running_total (float8)
(
    sfunc = grt_sfunc,
    stype = point,
    finalfunc = grt_finalfunc
);

Let's try our new function.

select greatest_running_total(amount order by id asc)
from entries;
 greatest_running_total
------------------------
       396271.274807863
(1 row)

Time: 3386.443 ms

Success! The new function returns the same result and the query is much simpler. Unfortunately, performance took a huge hit. It is now over 5x slower than before. Clearly, this is not acceptable.

Custom Aggregates in C

The majority of the computation is in the state transition function. What if we implement that in C? The code below is the same logic implemented in C. For more details on C extensions see the documentation.

#include "postgres.h"
#include "fmgr.h"
#include "utils/geo_decls.h"

"postgres.h" and "fmgr.h" are needed by all custom C functions. "utils/geo_decls.h" is needed to import the Point struct.

#ifdef PG_MODULE_MAGIC
PG_MODULE_MAGIC;
#endif

The PG_MODULE_MAGIC macro ensures that the extension won't load against an incompatible version of PostgreSQL.

PG_FUNCTION_INFO_V1(grt_sfunc);

Datum
grt_sfunc(PG_FUNCTION_ARGS)
{

PG_FUNCTION_INFO_V1 is a macro that specifies that a function will use the version 1 calling convention. Version 1 functions always have the same signature.

  Point *new_agg_state = (Point *) palloc(sizeof(Point));

Point is the C struct behind the SQL point type. It has two fields, x and y, which correspond directly to point[0] and point[1] in SQL.

palloc is a PostgreSQL provided function to allocate memory. PostgreSQL will ensure all memory allocated with palloc is released at an appropriate time.

  double el = PG_GETARG_FLOAT8(1);

Here we use the PostgreSQL-provided macro PG_GETARG_FLOAT8 to extract a float8 from the second argument to the function (arguments are numbered from zero, so argument 1 is the second argument).

  bool isnull = PG_ARGISNULL(0);
  if(isnull) {
    new_agg_state->x = el;
    new_agg_state->y = el;
    PG_RETURN_POINT_P(new_agg_state);
  }

If argument 0 (agg_state) is null, this is the first value provided to the aggregate. Return a new state with the current sum (x) and greatest sum (y) equal to that value. PG_ARGISNULL is a macro that evaluates to true if the argument is null. PG_RETURN_POINT_P is a macro that returns a point as the result of the function.

  Point *agg_state = PG_GETARG_POINT_P(0);

  new_agg_state->x = agg_state->x + el;
  if(new_agg_state->x > agg_state->y) {
    new_agg_state->y = new_agg_state->x;
  } else {
    new_agg_state->y = agg_state->y;
  }

  PG_RETURN_POINT_P(new_agg_state);
}

With the null state for the first row handled, it is just a simple translation of logic from PL/pgSQL: Compute the new running total (x), and update the greatest running total (y) if the new running total is larger than the old.
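
The same logic can be exercised without a PostgreSQL build. The sketch below is a plain C model, not the extension code: GrtState stands in for Point, a NULL pointer stands in for the SQL null state, and run_aggregate is a hypothetical driver that imitates calling the state transition function once per row and the final function once at the end.

```c
#include <stddef.h>

/* Stand-in for Point: x is the current sum, y is the greatest sum. */
typedef struct {
    double x;
    double y;
} GrtState;

/* Model of the state transition function. state is NULL before the
 * first row, mirroring the PG_ARGISNULL(0) branch. */
GrtState grt_step(const GrtState *state, double el)
{
    GrtState next;
    if (state == NULL) {
        next.x = el;
        next.y = el;
        return next;
    }
    next.x = state->x + el;
    next.y = (next.x > state->y) ? next.x : state->y;
    return next;
}

/* Model of the final function: the result is the greatest sum. */
double grt_final(const GrtState *state)
{
    return state->y;
}

/* Hypothetical driver for a single group of rows. Assumes n >= 1. */
double run_aggregate(const double *rows, size_t n)
{
    GrtState state;
    const GrtState *cur = NULL;
    for (size_t i = 0; i < n; i++) {
        state = grt_step(cur, rows[i]);
        cur = &state;
    }
    return grt_final(cur);
}
```

Seeding the state from the first value matters when every running total is negative: for the rows -5, -3 the model returns -5, which is the true maximum of the running totals, whereas a state that started at zero would report 0.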

Installing the new function takes several steps.

First, the shared library must be built and installed. The following Makefile will accomplish that task. Just run make install. PostgreSQL installed with Homebrew on OS X should already have all the required PostgreSQL dependencies; on Debian or Ubuntu, install postgresql-server-dev-9.5.

MODULES = grt
PGXS := $(shell pg_config --pgxs)
include $(PGXS)

Next, we must create the function in SQL.

create function
  grt_sfunc( point, float8 )
returns
  point
as
  'grt.so', 'grt_sfunc'
language
  c
immutable;

The final function and aggregate creation are unchanged.

create function grt_finalfunc(agg_state point)
returns float8
immutable
strict
language plpgsql
as $$
begin
  return agg_state[1];
end;
$$;

create aggregate greatest_running_total (float8)
(
    sfunc = grt_sfunc,
    stype = point,
    finalfunc = grt_finalfunc
);

Let's see what happens:

select greatest_running_total(amount order by id asc)
from entries;
 greatest_running_total
------------------------
       396271.274807863
(1 row)

Time: 825.365 ms

This is about 4x faster than the PL/pgSQL version, but still a bit slower than the subquery version.

Summary

PostgreSQL gives us many ways to tackle problems. In the greatest running total example, the initial solution with a subquery and window function is best.

However, sometimes a calculation may be exceedingly difficult without a custom aggregate. In these cases, a PL/pgSQL implementation may be ideal, although its performance can be lacking.

If a custom aggregate is necessary and performance of PL/pgSQL is insufficient, a C function may be a solution. But this step should not be undertaken lightly. A bug in a C extension can crash PostgreSQL and even corrupt data. The subtleties of C and the deployment and portability difficulties of custom C are costs that should only be paid when there is no reasonable alternative.

Posted in PostgreSQL and tagged with PostgreSQL