Monday, March 09, 2009

Parallel merge sort in Erlang

I've been thinking lately about the problem of scaling a service like Twitter or the Facebook news feed. When a user visits the site, you want to show her a list of all the recent updates from her friends, sorted by date. It's easy when the user doesn't have too many friends and all the updates are on a single database (as in Twoorl's case :P). You use this query:

"select * from update where uid in ([fid1], [fid2], ...) order by creation_date desc limit 20"

(After making sure you created an index on uid and creation_date, of course :) )

However, what do you do when the user has many thousands of friends, and each friend's updates are stored on a different database? Clearly, you should fetch those updates in parallel. In Erlang, it's easy. You use pmap():

fetch_updates(Uids) ->
    pmap(
      fun(Uid) ->
              Db = get_db_for_user(Uid),
              query(Db, [<<"select * from update where uid = ">>,
                         Uid, <<" order by creation_date desc limit 20">>])
      end, Uids).

%% Applies the function Fun to each element of the list in parallel
pmap(Fun, List) ->
    Parent = self(),

    %% spawn the processes, remembering a unique ref for each one
    Refs =
        lists:map(
          fun(Elem) ->
                  Ref = make_ref(),
                  spawn(
                    fun() ->
                            Parent ! {Ref, Fun(Elem)}
                    end),
                  Ref
          end, List),

    %% collect the results, in the original order
    lists:map(
      fun(Ref) ->
              receive
                  {Ref, Elem} ->
                      Elem
              end
      end, Refs).
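A quick sanity check: the replies can arrive in any order, but because pmap() collects them by matching on each Ref in the original order, the output lines up with the input. For example (a hypothetical shell session, assuming pmap() is exported from the psort module):

> psort:pmap(fun(X) -> X * X end, [1, 2, 3, 4]).
[1,4,9,16]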

Getting the updates is straightforward. However, what do you do once you've got them? Merging thousands of lists can take a long time, especially if you do it in a single process. The last thing you want is for your site's performance to grind to a halt as users add more friends.

Fortunately, merging a list of lists isn't too hard to do in parallel. Once you've implemented your nifty parallel merge algorithm, you can theoretically speed up response time by adding more cores to your web servers. This should help you maintain low latency even for very dense social graphs.

So, how do you merge a list of sorted lists in parallel in Erlang? There is probably more than one way of doing it, but this is what I came up with: you create a list of single-element lists. You scan through the main list, and for each pair of lists you spawn a process that merges the two lists and sends the result to the parent process. The parent process collects all the results, and repeats as long as there is more than one result. When only one result is left, the parent returns it.

Let's start with the base case of how to merge two lists:

%% Merges two sorted lists
merge(L1, L2) -> merge(L1, L2, []).

merge(L1, [], Acc) -> lists:reverse(Acc) ++ L1;
merge([], L2, Acc) -> lists:reverse(Acc) ++ L2;
merge(L1 = [Hd1 | Tl1], L2 = [Hd2 | Tl2], Acc) ->
    {Hd, L11, L21} =
        if Hd1 < Hd2 ->
                {Hd1, Tl1, L2};
           true ->
                {Hd2, L1, Tl2}
        end,
    merge(L11, L21, [Hd | Acc]).
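For example, assuming merge/2 is exported from the psort module, a quick shell test looks like this:

> psort:merge([1, 3, 5], [2, 4, 6]).
[1,2,3,4,5,6]

Note that merge() compares elements with <, i.e. the standard term order, so it merges in ascending order. To merge friend updates by date you'd compare creation_date fields instead, but the structure of the algorithm stays the same.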

Now, to the more interesting part: how to merge a list of sorted lists in parallel.

%% Merges all the lists in parallel
merge_all(Lists) ->
    merge_all(Lists, 0).

%% When there are no lists to collect or to merge, return an
%% empty list.
merge_all([], 0) ->
    [];

%% When no lists are left to merge, we collect the results of
%% all the merges that were done in spawned processes
%% and recursively merge them.
merge_all([], N) ->
    Lists = collect(N, []),
    merge_all(Lists, 0);

%% If only one list remains, merge it with the result
%% of all the pair-wise merges
merge_all([L], N) ->
    merge(L, merge_all([], N));

%% If two or more lists remain, spawn a process to merge
%% the first two lists and move on to the remaining lists
%% without blocking. Also, increment the number
%% of spawned processes so we know how many results
%% to collect later.
merge_all([L1, L2 | Tl], N) ->
    Parent = self(),
    spawn(
      fun() ->
              Res = merge(L1, L2),
              Parent ! Res
      end),
    merge_all(Tl, N + 1).

%% Collects the results of N merges (the order
%% doesn't matter).
collect(0, Acc) -> Acc;
collect(N, Acc) ->
    L = receive
            Res -> Res
        end,
    collect(N - 1, [L | Acc]).
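Assuming merge_all/1 is exported from psort, a quick shell test:

> psort:merge_all([[1, 4, 7], [2, 5, 8], [3, 6, 9]]).
[1,2,3,4,5,6,7,8,9]

One caveat: unlike pmap(), collect() receives bare messages without a ref, so this code assumes the calling process has nothing else in its mailbox while the merge is running.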

So, how well does this perform? I ran a benchmark on my 2.5 GHz Core 2 Duo MacBook Pro. First, I created a list of a million random numbers, each between 1 and a million:

> L = [random:uniform(1000000) || N <- lists:seq(1, 1000000)].

Then, I timed how long it takes to sort the list, first with lists:sort() and then with my shiny new parallel merge function.

> timer:tc(lists, sort, [L]).

Less than a second. lists:sort() is pretty fast!

Before we can pass the list of numbers into merge_all(), we have to break it up into multiple lists with a single element in each list:

> Lists = [[E] || E <- L].

Now for the moment of truth:

> timer:tc(psort, merge_all, [Lists]).

About 8.2 seconds :(

It's not exactly an improvement, but at least we learned something. In this test case, the overhead of process spawning and inter-process communication outweighed the benefits of parallelism. It would be interesting to run the same test on machines with more than two cores, but I don't have any at my disposal right now.

Another factor to consider is that lists:sort() is AFAIK implemented in C and therefore it has an unfair advantage over a function implemented in pure Erlang. Indeed, I tried sorting the list with the following pure Erlang quicksort function:

qsort([]) -> [];
qsort([H]) -> [H];
qsort([H | T]) ->
qsort([E || E <- T, E =< H]) ++
[H] ++
qsort([E || E <- T, E > H]).

> timer:tc(psort, qsort, [L]).

It took about 2 seconds to sort the million numbers.

The performance of merge_all() doesn't seem great, but consider how many processes we spawned during this test. The merge had ~19 levels of recursion (log2 500,000 ~= 19). The first level spawned 500,000 processes, and each subsequent level spawned half as many as the one before it, so the total is 500,000*(1 + 1/2 + 1/4 + ... + 1/2^18) ~= 1,000,000 processes. That works out to roughly 8.2 seconds / 1,000,000 processes ~= 8 microseconds per process, including all the spawning and message passing. Seen that way, it's actually quite impressive!

Let's go back to the original problem. It wasn't to sort one big list, but to merge a list of sorted lists with 20 items in each list. In this scenario, we still benefit from parallelism but we don't pay for the overhead of spawning hundreds of thousands of processes to merge tiny lists in the first few levels of recursion. Let's see how long it takes merge_all() to merge a million random numbers split between 50,000 sorted lists.

> Lists = [lists:sort([random:uniform(1000000) || N <- lists:seq(1, 20)])
|| N1 <- lists:seq(1, 50000)].
> timer:tc(psort, merge_all, [Lists]).

This function call took just over 2 seconds to run, roughly the same time as qsort(), yet it involved spawning 25,000*(1 - 0.5^15)/(1 - 0.5) ~= 50,000 processes! Now the benefits of concurrency start to become more obvious.

Can you think of ways to improve performance further? Let me know!