Fighting the skew in Spark

Since I started using Spark, shuffles and joins have become the bane of my life. We frequently host group therapy sessions at my company where we share Spark stories and try to answer questions like: when should I stop waiting for a task (that might never end) and pull the plug on it? Is there life after losing an executor?

Recently I have been running some joins that took far too long. Every single time, the vast majority of the tasks would finish quickly, but the job would spend most of its time (around 99% of it) finishing the last few. If you have ever experienced this, then what you have is a badly skewed join.

This is briefly explained in a Databricks talk. The talk only explains how to diagnose the problem, though, and does not put forward any way of fixing it apart from “picking a different algorithm” or “restructuring the data”. So the last time I had this problem, I came up with a solution that might not be the right thing to do in every case, but it did the job for me.

For the purpose of demonstration, we will create a large skewed rdd and a smaller non-skewed rdd and join them. I start by creating the large skewed rdd in a similar manner to the Databricks talk:

from math import exp
from random import randint
from datetime import datetime

from pyspark import SparkContext

# reuse the SparkContext of the pyspark shell if there is one, otherwise create it
sc = SparkContext.getOrCreate()

def count_elements(splitIndex, iterator):
    # count the records held by one partition
    n = sum(1 for _ in iterator)
    yield (splitIndex, n)

def get_part_index(splitIndex, iterator):
    # pair every record with the index of the partition it lives in
    for it in iterator:
        yield (splitIndex, it)

num_parts = 16
# create the large skewed rdd: partition x holds floor(exp(x)) elements
skewed_large_rdd = sc.parallelize(range(0, num_parts), num_parts).flatMap(lambda x: range(0, int(exp(x))))

print("skewed_large_rdd has %d partitions." % skewed_large_rdd.getNumPartitions())
print("The distribution of elements across partitions is: %s"
      % str(skewed_large_rdd.mapPartitionsWithIndex(count_elements).collect()))

# put it in (key, value) form, using the partition index as the key
skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(get_part_index).cache()
skewed_large_rdd has 16 partitions.
The distribution of elements across partitions is: [(0, 1), (1, 2), (2, 7), (3, 20), (4, 54), (5, 148), (6, 403), (7, 1096), (8, 2980), (9, 8103), (10, 22026), (11, 59874), (12, 162754), (13, 442413), (14, 1202604), (15, 3269017)]

The rdd is made of tuples (num_partition, number), where num_partition is the index of the partition and number is an integer between 0 and floor(exp(num_partition)) - 1. This creates an rdd with an artificial skew, where the number of items in each partition grows exponentially with its index. So partition 0 has 1 item, partition 1 has floor(exp(1)) = 2 items, partition 2 has floor(exp(2)) = 7 items, and so on.
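As a quick sanity check of those numbers, the pure-Python snippet below (no Spark needed) recomputes the per-partition counts and the share of all records that end up in the last partition. Because each partition is roughly e times larger than the previous one, that share comes out close to (e - 1)/e, about 63%, which is why a single task ends up dominating the runtime.

```python
from math import exp

num_parts = 16

# partition i of the skewed rdd holds floor(exp(i)) records
counts = [int(exp(i)) for i in range(num_parts)]
total = sum(counts)

# share of all records sitting in the last (largest) partition
last_share = counts[-1] / total

print(counts)
print("total records: %d" % total)
print("share in partition %d: %.1f%%" % (num_parts - 1, 100 * last_share))
```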

We create a second, small dummy dataset with keys 0 to num_parts - 1, where each value is simply the key itself.

small_rdd = sc.parallelize(range(0,num_parts), num_parts).map(lambda x: (x, x)).cache()

Both rdds are cached in memory, so now we can do what we came for: the join.

t0 = datetime.now()
result = skewed_large_rdd.leftOuterJoin(small_rdd)
result.count()  # the join is lazy; count() forces it to actually run
print("The direct join takes %s" % str(datetime.now() - t0))
The direct join takes 0:01:46.539335

In this particular example, it took ~20 seconds to complete 31 out of 32 tasks and the remaining time to finish off the last one. This is a classic symptom of skew. We can diagnose it in the Spark web UI by checking the time taken per task: most tasks finish very quickly, while a few straggle behind for significantly longer.
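The same check can be written down as a tiny heuristic. The helper find_stragglers below is hypothetical (not a Spark API), and its 5× median threshold is an arbitrary choice: it simply flags tasks whose cost is far above the typical one. Since task durations have to be read off the web UI by hand, the example feeds it the per-partition record counts printed earlier as a rough proxy for task time.

```python
from statistics import median


def find_stragglers(task_costs, factor=5.0):
    """Return the indices of tasks whose cost exceeds `factor` times the median.

    `task_costs` can be task durations taken from the Spark web UI, or
    per-partition record counts used as a rough proxy for them.
    """
    typical = median(task_costs)
    return [i for i, cost in enumerate(task_costs) if cost > factor * typical]


# per-partition record counts of skewed_large_rdd, as printed earlier in the post
counts = [1, 2, 7, 20, 54, 148, 403, 1096, 2980, 8103,
          22026, 59874, 162754, 442413, 1202604, 3269017]
print(find_stragglers(counts))  # [10, 11, 12, 13, 14, 15]
```

A balanced workload returns an empty list, while an exponential layout like this one flags every partition beyond the middle of the distribution.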


One quick fix for a skewed join is to simply drop the largest groups of records (the heaviest keys), which may contain outliers anyway. However, let’s imagine we don’t want to throw away any data. In that case, I came up with a semi-quick fix that deals with the skew at the expense of some data replication. It works like this:

  • We replicate the data in the small rdd N times by creating a new key (original_key, v), where v takes every value between 0 and N-1. The value does not change, i.e. it is the same value that was associated with the original key.
  • We take the large skewed rdd and add some randomness to the key by turning it into (original_key, random_int), where random_int takes a random value between 0 and N-1. Note that in this case we are NOT replicating the data in the large rdd. We are simply splitting the keys, so that values associated with the same original key are now spread over N buckets.
  • We then perform the join between these datasets.
  • Finally, we remove random_int from the key to obtain the final result of the join.
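The four steps above can be checked on a toy dataset in plain Python, without a cluster. This is a minimal sketch that assumes lists of (key, value) tuples stand in for rdds; left_outer_join and salted_left_outer_join are hypothetical helpers, not Spark APIs. It shows that salting the large side and replicating the small side gives exactly the same result as a plain left outer join, because every salt 0..N-1 exists on the small side for every key.

```python
import random
from collections import Counter


def left_outer_join(left, right):
    """Plain-Python stand-in for RDD.leftOuterJoin on lists of (key, value)."""
    index = {}
    for k, v in right:
        index.setdefault(k, []).append(v)
    out = []
    for k, v in left:
        # unmatched keys get None on the right side, as in leftOuterJoin
        for w in index.get(k, [None]):
            out.append((k, (v, w)))
    return out


def salted_left_outer_join(large, small, n, rng=random):
    # step 1: replicate every small record n times under keys (key, 0..n-1)
    small_salted = [((k, i), v) for k, v in small for i in range(n)]
    # step 2: give each large record one random salt in 0..n-1 (no replication)
    large_salted = [((k, rng.randint(0, n - 1)), v) for k, v in large]
    # step 3: join on the salted keys
    joined = left_outer_join(large_salted, small_salted)
    # step 4: strip the salt to recover the original keys
    return [(k, pair) for (k, _salt), pair in joined]


# toy data: key 0 is heavily skewed, key 7 has no match on the small side
large = [(0, i) for i in range(1000)] + [(1, i) for i in range(3)] + [(7, 0)]
small = [(0, "a"), (1, "b")]

direct = left_outer_join(large, small)
salted = salted_left_outer_join(large, small, n=10, rng=random.Random(42))
print(Counter(direct) == Counter(salted))  # True
```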
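N = 100  # parameter to control the level of data replication
# replicate the small rdd: ((key, value), n) -> ((key, n), value) for every n in 0..N-1
small_rdd_transformed = small_rdd.cartesian(sc.parallelize(range(0, N))).map(lambda x: ((x[0][0], x[1]), x[0][1])).coalesce(num_parts).cache()
# salt the large rdd: add a random int in 0..N-1 to form a new key (no replication here)
skewed_large_rdd_transformed = skewed_large_rdd.map(lambda x: ((x[0], randint(0, N - 1)), x[1])).partitionBy(num_parts).cache()

t0 = datetime.now()
# join on the salted keys, then drop the random int to recover the original keys
result = skewed_large_rdd_transformed.leftOuterJoin(small_rdd_transformed).map(lambda x: (x[0][0], x[1]))
result.count()  # the join is lazy; count() forces it to actually run
print("The hashed join takes %s" % str(datetime.now() - t0))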
The hashed join takes 0:00:04.893182

So, as you can see, using this trick we have reduced the time needed for the join from about 1 min 46 s to about 5 seconds!
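To get a feel for the trade-off behind that speed-up, the sketch below estimates, for the dataset in this post, how much the heaviest unit of work shrinks after salting and how much the small rdd grows. It is a back-of-the-envelope model, not a Spark measurement: it assumes each key's records are spread perfectly evenly over the N salts, ignores the shuffle the join itself performs, and uses Python's built-in hash() as a stand-in for PySpark's portable_hash, so the exact partition assignment will differ.

```python
from collections import defaultdict
from math import exp

NUM_PARTS = 16  # num_parts in the post
N = 100         # salt range used in the post

# records per key in the skewed rdd: key k has floor(exp(k)) records
counts = {k: int(exp(k)) for k in range(NUM_PARTS)}

# before salting, all records of a key go to the same task,
# so one task has to process the whole of the heaviest key
largest_before = max(counts.values())

# after salting, each key's records are assumed to be spread evenly over
# N salted keys, and each salted key is hashed to one of NUM_PARTS partitions
# (Python's built-in hash() stands in for PySpark's portable_hash here)
expected_load = defaultdict(float)
for k, c in counts.items():
    for s in range(N):
        expected_load[hash((k, s)) % NUM_PARTS] += c / N

largest_after = max(expected_load.values())
# the price: every small-side record is replicated N times
replicated_small = NUM_PARTS * N

print("heaviest key before salting: %d records" % largest_before)
print("heaviest partition after salting: ~%.0f records" % largest_after)
print("small rdd grows from %d to %d records" % (NUM_PARTS, replicated_small))
```

The larger N is, the more evenly the heavy keys are spread, but the small side grows linearly with N, so N is worth keeping only as large as the skew requires.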
