Integrating Spark with scikit-learn, visualizing eigenvectors, and fun!

Three topics in this post, to make up for the long hiatus!


Apache Spark’s MLlib has built-in support for many machine learning algorithms, but not all of them, of course. Fortunately, scikit-learn (sklearn) functions can be made to run inside Spark in a distributed fashion, which can make things much more efficient. That’s what I’m going to talk about here.

As a practical example, let’s consider k-Nearest-Neighbors (k-NN). Spark’s MLlib doesn’t have built-in support for this, but scikit-learn does.

So let’s talk about sklearn for a minute. If you have a large number of points, say a million or more, and you want to obtain nearest neighbors for all of them (as may be the case with a k-NN-based recommender system), sklearn’s NearestNeighbors on a single machine can be hard to work with. The fit() method isn’t what takes a long time; the expensive part is subsequently producing the results for the large number of queries with kneighbors().

In the most straightforward deployment, if you try to send kneighbors() all point vectors in a single large matrix and ask it to come up with nearest neighbors for all of them in one fell swoop, it quickly exhausts the RAM and brings the machine to a crawl. Alternatively, the batch iteration method that I mentioned before is a good solution: after performing the initial fit, you can break the large matrix into chunks and obtain their neighbors chunk by chunk. This eases memory consumption, but can take a long time.
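The chunk-by-chunk idea can be sketched as follows. This is only a minimal illustration, not the exact code I used: the helper name kneighbors_in_chunks, the chunk size, and the random toy data are all assumptions made for the example.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def kneighbors_in_chunks(knn, queries, chunk_size=1000, n_neighbors=5):
    """Query a fitted NearestNeighbors object one chunk of rows at a time."""
    distances, indices = [], []
    for start in range(0, queries.shape[0], chunk_size):
        chunk = queries[start:start + chunk_size]
        dist, idx = knn.kneighbors(chunk, n_neighbors=n_neighbors)
        distances.append(dist)
        indices.append(idx)
    # Stitch the per-chunk results back into two full-size arrays
    return np.vstack(distances), np.vstack(indices)


# Toy data standing in for the real vectors (an assumption for illustration)
rng = np.random.default_rng(0)
points = rng.random((2500, 8))

knn = NearestNeighbors().fit(points)
distances, indices = kneighbors_in_chunks(knn, points, chunk_size=1000)
```

Only one chunk's worth of query results is being computed at any moment, which is what keeps peak memory down; the chunks are still processed one after another on a single machine, hence the long runtime.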

There are of course approximate nearest-neighbor implementations such as Spotify’s Annoy. In my use case, Annoy actually did worse than sklearn’s exact neighbors, because Annoy does not have built-in support for matrices: if you want to evaluate nearest neighbors for n query points, you have to loop through each of your n queries one at a time, whereas sklearn’s k-NN implementation can take in a single matrix containing many query points and return nearest neighbors for all of them at a blow, relatively quickly. Your mileage may vary. I’ll talk about Annoy again a little later.

To summarize the problem:

  • sklearn has good support for k-NN; Spark doesn’t.
  • sklearn’s k-NN fit() isn’t a problem.
  • sklearn’s k-NN kneighbors() is a computational bottleneck for large data sets, which makes it a good candidate for parallelization.

This is where Spark comes in. All we have to do is insert kneighbors() into a Spark map function after setting the stage for it. This is especially neat if you’re already working in Spark and/or if your data is already in HDFS to begin with, as is commonly the case.

Below is a simplified Python (PySpark) code snippet to make this approach clear:

# Imports
from pyspark import SparkConf, SparkContext
from sklearn.neighbors import NearestNeighbors

# Let's say we already have a Spark object containing
# all our vectors, called myvecs

# Create kNN tree locally, and broadcast
myvecscollected = myvecs.collect()
knnobj = NearestNeighbors().fit(myvecscollected)
bc_knnobj = sc.broadcast(knnobj)

# Get neighbors for each point, distributedly
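# (kneighbors() expects a 2-D array, so each single vector is wrapped in a list)
results = myvecs.map(lambda x: bc_knnobj.value.kneighbors([x]))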

Boom! That’s all you need. The key point in the above code is that we were able to call sklearn’s NearestNeighbors kneighbors() method inside of Spark’s map(), which means that Spark can run the queries in parallel across its workers.
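A possible variation, which I have not benchmarked, is to query one whole partition at a time with mapPartitions() rather than one vector at a time with map(), so that kneighbors() receives a matrix per call. The sketch below is an assumption-laden illustration: neighbors_for_partition is a hypothetical helper, and the toy data only exists so that the function can be checked locally without a cluster.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def neighbors_for_partition(rows, knn, n_neighbors=5):
    """Stack one partition's vectors and query them in a single call."""
    batch = [np.asarray(row, dtype=float) for row in rows]
    if not batch:
        # Empty partitions yield nothing
        return
    _, indices = knn.kneighbors(np.vstack(batch), n_neighbors=n_neighbors)
    for row_indices in indices:
        yield row_indices.tolist()


# Local check on toy data (an assumption for illustration), no Spark needed
rng = np.random.default_rng(1)
points = rng.random((200, 4))
knn = NearestNeighbors().fit(points)
local_result = list(neighbors_for_partition(iter(points[:10]), knn, n_neighbors=3))

# With Spark, reusing the names from the snippet above, it might be wired up as:
# results = myvecs.mapPartitions(
#     lambda rows: neighbors_for_partition(rows, bc_knnobj.value))
```

Whether this beats the plain map() version will depend on partition sizes and on how much memory each worker has, so treat it as something to experiment with rather than a guaranteed improvement.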

(You can do the same thing using Annoy instead of sklearn, except that instead of broadcasting the Annoy object to workers, you need to serialize it to a file and distribute the file to workers instead. This code shows you how.)

In my use case, harnessing Spark to distribute my sklearn code brought my runtime down from hours to minutes!

Update: between the time I first considered this problem and now, there has also emerged a Spark package for distributing sklearn functionality over Spark, as well as a more comprehensive integration called sparkit-learn. So there are several solutions available now. I still like the approach shown above for its simplicity, and for not requiring any extraneous code.


A beautiful interactive visualization of eigenvectors, courtesy of the wonderful folks at Setosa.

The thing that I love about this viz is that it doesn’t just show how eigenvectors are computed, it gives you an intuition for what they mean.


Lastly, and just for fun: Is it Pokemon or Big Data? ☺



  1. Good idea, Apu. I like it. I am wondering, if there is a way to distribute KNN fit(). It works smooth, however if we have a dataset with 100K+ samples, it takes significant time to compute the model, since it is not being parallelized.


    1. Unfortunately, I don’t know of a way to distribute sklearn’s k-NN fit() method. Does playing with the leaf_size option help reduce runtime? I didn’t have trouble with 100K+ samples, but it might have to do with the sparsity/dimensionality of the data.

      I was hoping that the latest version of Spark (2.0) might have built-in support for nearest neighbors (either exact or approximate), but sadly that hasn’t happened yet, though there is a JIRA item for it. In the meantime, you may need to hack together your own solution for your use case. Good luck!


    1. In that case you won’t be able to leverage sklearn’s NearestNeighbors with Spark. (My goal is to – where possible – avoid having to write my own custom code, and to leverage tools that already exist. But this is, of course, not possible in all cases.)


  2. I did not understand this step
    # Let’s say we already have a Spark object containing
    # all our vectors, called myvecs

    I am creating a list of vectors and sending this list to fit(), and then I am getting this error:
    “array = array.astype(np.float64)
    TypeError: float() argument must be a string or a number”

    Can you please clarify this or can you please mail me a sample code of your approach.
    mail id:



    1. I think you may need to convert your list of vectors to an array in order to run the .fit() method. This would look like:

      import numpy as np
      myvecscollected = np.vstack(tuple(myveclist))
      knnobj = NearestNeighbors().fit(myvecscollected)

      Here of course I’m assuming that myveclist is a local object. If your list of vectors is in Spark you’ll need to run .collect() first.


  3. Hi, I followed your method and I’m really curious how you got that significant speedup. I think I’m doing something wrong because, in my case, this is really slow compared to the single-machine version.

    So, I called fit() on a vector with 30M elements, and if I want to query kneighbors() on 1M elements it takes 50 seconds locally (not using Spark at all), whereas on more machines using Spark’s map() it takes forever.

    I also tried mapPartitions(), sending the whole query vector to kneighbors() in order to make use of the n_jobs=-1 parameter I passed to the NearestNeighbors object.

    Do you have any suggestions, ideas where I’m going wrong? Thanks

