This blog entry’s going to show you how to implement k-means clustering a bit more efficiently than the naive algorithm, with the help of a little algebra and some sparse vector manipulation. The current implementation of k-means clustering in LingPipe (see the LingPipe Javadoc) is woefully slow.
We’re working on a medium-size clustering problem for a customer, with a few hundred thousand short-ish text messages to cluster. It was taking roughly forever (OK, several hours per epoch).
I knew it was slow when I was writing it. You can tell from the apologetic implementation note in the Javadoc:
Implementation Note: The current implementation is an inefficient brute-force approach that computes the distance between every element and every centroid on every iteration. These distances are computed between maps, which itself is not very efficient. In the future, we may substitute a more efficient version of k-means implemented to this same interface.
More efficient implementations could use kd-trees to index the points, could cache scores of input feature vectors against clusters, and would only compute differential updates when elements change cluster. Threading could also be used to compute the distances.
All true, but there’s a simpler approach than kd-trees, multiple threads, or caching. I just implemented it, and now epochs take roughly no time at all (OK, about a minute).
K-means clustering is very simple. You start with a set of items and a feature extractor that converts them to feature vectors, which are maps from strings to numbers. In the first step, items are assigned to clusters somehow, for instance randomly (though there are better ways to do this for some purposes). Then, for each epoch until convergence or a maximum number of epochs, two things are computed. First, the mean (aka centroid) of each cluster is found by averaging its vectors dimensionwise. Second, the Euclidean distance from each item to each centroid is computed. At the end of the epoch, each item is reassigned to the cluster with the closest centroid.
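Here is a minimal sketch of that loop in Java. It is not LingPipe's implementation. It assumes dense double[] vectors rather than string-keyed maps, and the class and method names (NaiveKMeans, centroids, reassign) are hypothetical. The sketch compares squared distances, which pick the same closest centroid as plain distances.

```java
import java.util.Arrays;

// Hypothetical sketch of naive k-means over dense vectors; not LingPipe's API.
class NaiveKMeans {

    /** Recomputes each centroid as the dimensionwise mean of its assigned items. */
    static double[][] centroids(double[][] items, int[] assignment, int k) {
        int dims = items[0].length;
        double[][] sums = new double[k][dims];
        int[] counts = new int[k];
        for (int n = 0; n < items.length; ++n) {
            int c = assignment[n];
            counts[c]++;
            for (int d = 0; d < dims; ++d)
                sums[c][d] += items[n][d];
        }
        for (int c = 0; c < k; ++c)
            if (counts[c] > 0)             // an empty cluster keeps an all-zero centroid
                for (int d = 0; d < dims; ++d)
                    sums[c][d] /= counts[c];
        return sums;
    }

    /** Reassigns every item to its closest centroid; returns true if any assignment changed. */
    static boolean reassign(double[][] items, double[][] centroids, int[] assignment) {
        boolean changed = false;
        for (int n = 0; n < items.length; ++n) {
            int best = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; ++c) {
                double dist = 0.0;
                // Brute force: touches every dimension of every centroid for every item.
                for (int d = 0; d < items[n].length; ++d) {
                    double diff = items[n][d] - centroids[c][d];
                    dist += diff * diff;   // squared Euclidean distance
                }
                if (dist < bestDist) { bestDist = dist; best = c; }
            }
            if (assignment[n] != best) { assignment[n] = best; changed = true; }
        }
        return changed;
    }

    public static void main(String[] args) {
        double[][] items = { {0, 0}, {0, 1}, {10, 10}, {10, 11} };
        int[] assignment = { 0, 1, 0, 1 };  // deliberately poor initial assignment
        int k = 2, maxEpochs = 10;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            double[][] cs = centroids(items, assignment, k);
            if (!reassign(items, cs, assignment)) break;   // converged
        }
        System.out.println(Arrays.toString(assignment)); // prints [0, 0, 1, 1]
    }
}
```

The inner loop over every dimension of every centroid is the cost that the rest of this entry sets out to avoid.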
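Typically, when I do heavy lifting with features, I extract the features once and then reuse the vectors. I did that with k-means, but I didn’t consider the downstream computations closely enough. Suppose you have a 100,000-feature centroid vector and a 100-feature item vector. Then it’s rather inefficient to compute Euclidean distance, which is defined as:
$$\mathrm{dist}(\mathit{elt}, \mathit{centroid}) = \sqrt{\sum_i (\mathit{elt}[i] - \mathit{centroid}[i])^2}$$
The sum runs over every dimension in which either vector is non-zero, so the work is driven by the large centroid, not the small item.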
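Of course, we don’t care about distance per se, only about which centroid is closest. The square root is monotonic, so we can skip it and work with squared distances:
$$\mathrm{sqDist}(\mathit{elt}, \mathit{centroid}) = \sum_i (\mathit{elt}[i] - \mathit{centroid}[i])^2$$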
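To show where the time goes with map-based vectors, here is a hypothetical sketch of squared distance between string-keyed maps. The class name MapDistance is made up, and this is not LingPipe's actual code; it only illustrates the brute-force approach described in the implementation note.

```java
import java.util.Map;

// Hypothetical sketch of a naive map-based squared distance; not LingPipe's actual code.
class MapDistance {

    /** Squared Euclidean distance between two sparse vectors stored as maps. */
    static double squaredDistance(Map<String, Double> elt, Map<String, Double> centroid) {
        double sum = 0.0;
        // Every key in the (possibly huge) centroid is visited, even when the
        // item has only a handful of non-zero features.
        for (Map.Entry<String, Double> entry : centroid.entrySet()) {
            // Lookups return boxed Doubles that must be unboxed for arithmetic.
            double diff = elt.getOrDefault(entry.getKey(), 0.0) - entry.getValue();
            sum += diff * diff;
        }
        // Features present in the item but absent from the centroid.
        for (Map.Entry<String, Double> entry : elt.entrySet()) {
            if (!centroid.containsKey(entry.getKey())) {
                double v = entry.getValue();
                sum += v * v;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        Map<String, Double> centroid = Map.of("a", 0.5, "b", 1.0, "c", 2.0);
        Map<String, Double> elt = Map.of("a", 2.0, "z", 1.0);
        System.out.println(squaredDistance(elt, centroid)); // prints 8.25
    }
}
```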
There are a couple of problems. Java hash maps with numeric values are wasteful in space, and converting back and forth between boxed objects and primitives for arithmetic costs time. They’re also slow to iterate over (even linked hash maps). That’s an obvious bottleneck. But the real bottleneck is iterating at all. Remember, the fastest code is code that doesn’t execute at all (as I noted in a comment on point 6 of John Langford’s excellent blog entry on machine learning code optimization).
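The key to optimization is noting that most of the elt[i] values are zero. So at the beginning of each epoch, we compute the squared length of each centroid and store it:
$$\|\mathit{centroid}\|^2 = \sum_i \mathit{centroid}[i]^2 = \sum_i (0 - \mathit{centroid}[i])^2$$
The second form of this equation is key. It shows that the squared length of the centroid is its squared distance from a vector that is zero in every dimension. Adding and subtracting this squared length inside the squared distance gives:
$$\mathrm{sqDist}(\mathit{elt}, \mathit{centroid}) = \|\mathit{centroid}\|^2 + \sum_i \left[ (\mathit{elt}[i] - \mathit{centroid}[i])^2 - (0 - \mathit{centroid}[i])^2 \right]$$
Now note that when elt[i] = 0, the two terms inside the sum cancel. So we keep going with:
$$\mathrm{sqDist}(\mathit{elt}, \mathit{centroid}) = \|\mathit{centroid}\|^2 + \sum_{i : \mathit{elt}[i] \neq 0} \left[ (\mathit{elt}[i] - \mathit{centroid}[i])^2 - \mathit{centroid}[i]^2 \right] = \|\mathit{centroid}\|^2 + \sum_{i : \mathit{elt}[i] \neq 0} \left( \mathit{elt}[i]^2 - 2\,\mathit{elt}[i]\,\mathit{centroid}[i] \right)$$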
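As a quick sanity check of the identity, here is a small worked example with made-up numbers. It uses a three-dimensional centroid and an item that is non-zero only in the first dimension. Both the direct sum and the sparse form give the same squared distance.

```latex
\begin{aligned}
\mathit{centroid} &= (0.5,\ 1,\ 2), \qquad \mathit{elt} = (2,\ 0,\ 0) \\
\|\mathit{centroid}\|^2 &= 0.25 + 1 + 4 = 5.25 \\
\text{direct: } \mathrm{sqDist} &= (2 - 0.5)^2 + (0 - 1)^2 + (0 - 2)^2 = 2.25 + 1 + 4 = 7.25 \\
\text{sparse: } \mathrm{sqDist} &= 5.25 + \left(2^2 - 2 \cdot 2 \cdot 0.5\right) = 5.25 + 2 = 7.25
\end{aligned}
```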
This simple algebra means that once the squared length of each centroid has been computed, each distance computation takes only a number of operations proportional to the number of non-zero elements in the item’s feature vector. It no longer depends on the number of dimensions in the centroid.
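The rearranged formula translates directly into code. The sketch below is not LingPipe's code. It assumes a sparse item stored as parallel int[] indices and double[] values, and a dense double[] centroid. The class name SparseKMeansDistance and its methods are hypothetical.

```java
// Hypothetical sketch of the sparse squared-distance trick; not LingPipe's API.
class SparseKMeansDistance {

    /** Squared length ||centroid||^2, computed once per centroid per epoch. */
    static double squaredLength(double[] centroid) {
        double sum = 0.0;
        for (double v : centroid) sum += v * v;
        return sum;
    }

    /**
     * Squared distance from a sparse item (parallel indices/values arrays) to a
     * dense centroid, computed as ||c||^2 + sum over non-zero i of (e[i]^2 - 2 e[i] c[i]).
     * The loop runs only over the item's non-zero features.
     */
    static double squaredDistance(int[] indices, double[] values,
                                  double[] centroid, double centroidSqLength) {
        double sum = centroidSqLength;
        for (int j = 0; j < indices.length; ++j) {
            double e = values[j];
            sum += e * e - 2.0 * e * centroid[indices[j]];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] centroid = { 0.5, 1.0, 2.0 };
        double sqLength = squaredLength(centroid);   // 5.25
        int[] indices = { 0 };                       // item (2, 0, 0) in sparse form
        double[] values = { 2.0 };
        System.out.println(squaredDistance(indices, values, centroid, sqLength)); // prints 7.25
    }
}
```

One design caveat: in floating point, this expanded form can come out slightly negative or lose precision when an item sits very close to a large centroid. That is harmless for picking the nearest cluster, but worth knowing if the distances themselves are reported.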
Combining this with more reasonable sparse vector representations (parallel int index and double value arrays) and reasonable dense vector representations (i.e., arrays of double) makes k-means hundreds of times faster for the kinds of problems we care about: large sets of sparse feature vectors. I know I can squeeze out at least another factor of 2 or 3 on my dual quad-core machine by multi-threading. K-means and other EM-like algorithms are trivial to parallelize because the heavy computations in the inner loops are independent of one another.
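Here is one way the reassignment step could be multi-threaded, as a sketch. It uses Java's parallel streams (IntStream.parallel()); that is an assumption, not how LingPipe or this entry actually threads the work, and the speedup is not measured here. The class name ParallelAssign is hypothetical. The design relies on the independence noted above: the centroids are read-only during reassignment, and each item writes only its own slot of the result array.

```java
import java.util.stream.IntStream;

// Hypothetical sketch of parallel reassignment; not LingPipe's API.
class ParallelAssign {

    static double squaredLength(double[] v) {
        double sum = 0.0;
        for (double x : v) sum += x * x;
        return sum;
    }

    /** Assigns each sparse item (indices[n], values[n]) to its nearest dense centroid. */
    static int[] assign(int[][] indices, double[][] values, double[][] centroids) {
        // Squared centroid lengths are computed once, before the parallel loop.
        double[] sqLengths = new double[centroids.length];
        for (int c = 0; c < centroids.length; ++c)
            sqLengths[c] = squaredLength(centroids[c]);
        int[] assignment = new int[indices.length];
        // Items are independent: each task reads shared, unchanging data and
        // writes only assignment[n], so no locking is needed.
        IntStream.range(0, indices.length).parallel().forEach(n -> {
            int best = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; ++c) {
                double dist = sqLengths[c];
                for (int j = 0; j < indices[n].length; ++j) {
                    double e = values[n][j];
                    dist += e * e - 2.0 * e * centroids[c][indices[n][j]];
                }
                if (dist < bestDist) { bestDist = dist; best = c; }
            }
            assignment[n] = best;
        });
        return assignment;
    }
}
```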
The next obvious speedup would be to improve memory locality. This is Langford’s tip number 3, and it is at the heart of many of the speedups in Jon Bentley’s wonderful book Programming Pearls, whose pearls were drawn from his Communications of the ACM columns.
As it is, I iterate over the items. For each item, I iterate over the cluster centroids computing distances, which involves iterating over the non-zero values in the item’s feature vector. It’d be much better for memory locality to invert the inner loops. That means iterating over the non-zero values in the item’s feature vector and, for each one, pulling back an array of centroid values indexed by cluster. That way, the number of non-contiguous memory lookups would equal the number of non-zero features, and everything else would be local.
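Here is a sketch of that inverted loop order, assuming the centroids are stored feature-major as a hypothetical byFeature[feature][cluster] array. The class FeatureMajorAssign and its methods are made up for illustration. Each non-zero feature of the item costs one lookup of a row, and the scan across clusters within that row is contiguous.

```java
// Hypothetical sketch of a feature-major centroid layout; not LingPipe's API.
class FeatureMajorAssign {

    /** Transposes cluster-major centroids[cluster][feature] into byFeature[feature][cluster]. */
    static double[][] transpose(double[][] centroids) {
        int k = centroids.length, dims = centroids[0].length;
        double[][] byFeature = new double[dims][k];
        for (int c = 0; c < k; ++c)
            for (int d = 0; d < dims; ++d)
                byFeature[d][c] = centroids[c][d];
        return byFeature;
    }

    /** Squared length of each centroid, indexed by cluster. */
    static double[] squaredLengths(double[][] centroids) {
        double[] result = new double[centroids.length];
        for (int c = 0; c < centroids.length; ++c)
            for (double v : centroids[c]) result[c] += v * v;
        return result;
    }

    /** Index of the closest centroid to one sparse item, looping features outside clusters. */
    static int closest(int[] indices, double[] values, double[][] byFeature, double[] sqLengths) {
        double[] dist = sqLengths.clone();            // every cluster starts at ||c||^2
        for (int j = 0; j < indices.length; ++j) {
            double e = values[j];
            double eSq = e * e;
            double[] row = byFeature[indices[j]];     // one non-local lookup per non-zero feature
            for (int c = 0; c < row.length; ++c)      // contiguous scan across clusters
                dist[c] += eSq - 2.0 * e * row[c];
        }
        int best = 0;
        for (int c = 1; c < dist.length; ++c)
            if (dist[c] < dist[best]) best = c;
        return best;
    }
}
```

The transpose would need to be redone after the centroids are recomputed each epoch. Alternatively, the centroids could be accumulated directly in feature-major order.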