Fast indirect sorting in Java

I was recently writing some performance-sensitive code in which I had a double array of distances (one per element), and I wanted to get a list of elements sorted by distance:

double[]  distances = { d1, d2, ..., dN };
Element[] elements  = { e1, e2, ..., eN };
// do something here to sort elements[] by distances[]

Java provides Arrays.sort for direct sorting; that is, it's easy to sort either distances or elements by its own natural ordering. But in this situation the two arrays are tied together only by index, so sorting one by the other would require a comparator that maintains a reverse lookup from each Element to either its index or its distance. That's a lot of overhead – particularly because such a map would require generic boxing of either type of value.
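
For contrast, here's roughly what the conventional comparator-based version looks like – a sketch of the approach the rest of this post avoids, not code from my project, assuming java.util.Arrays is imported:

// Sort an array of boxed indexes with a comparator that looks up each index's
// distance. It works, but it allocates an Integer per element and every
// comparison goes through the comparator plus an unboxing step.
Integer[] order = new Integer[distances.length];
for (int i = 0; i < order.length; ++i) order[i] = i;
Arrays.sort(order, (a, b) -> Double.compare(distances[a], distances[b]));
// order[k] is now the original index of the k-th closest element, so the
// sorted view is elements[order[0]], elements[order[1]], ...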

Luckily there’s an interesting way to solve this problem that meets the following requirements:

  1. No extra memory is allocated, aside from the two arrays above, each of which is allowed to be clobbered
  2. We do not write any sorting code; we just use standard APIs
  3. We end up being able to access the distance-sorted Elements each in constant time

I should also mention that the resulting ordering isn’t exact, but it is very close.

Before I go into how I solved it (which took a while to think of), you should see what you come up with. It’s a fun problem.

The solution

First in code:

int largest = distances.length - 1;
long mask = -1L;
while ((~mask & largest) != largest) mask <<= 1;
for (int i = 0; i < distances.length; ++i)
  distances[i] = Double.longBitsToDouble(
                   Double.doubleToLongBits(distances[i]) & mask | i);
Arrays.sort(distances);

Now distances[] is sorted such that the distance of elements[(int) (Double.doubleToLongBits(distances[i]) & ~mask)] is ascending for ascending values of i.
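
Or, restating that expression as code (the variable name is just illustrative):

// Constant-time access to the i-th closest element after the sort.
Element ithClosest = elements[(int) (Double.doubleToLongBits(distances[i]) & ~mask)];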

How it works

Each distance is encoded as a double-precision float, which internally looks like this:

  +--- sign
  |
  | +- exponent
  | |
  S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM
                |---------------------52 mantissa bits-------------------|

The mantissa for all but the smallest numbers is normal, meaning that it’s interpreted as though there were a leading 1 in a 53rd bit. This puts an upper bound on the significance of low-order mantissa bits, which is what we need for the code above to work.
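
To put a number on that upper bound: for any normalized double d, the low k mantissa bits together are worth less than 2^(k-52) of d's magnitude. A quick sanity check for k = 20 (the value 123.456 is arbitrary; Math.ulp gives the weight of the lowest mantissa bit):

// The worst the low 20 mantissa bits can do is (2^20 - 1) ulps, which for any
// normalized double is less than 2^-32 of the value itself.
double d = 123.456;
double worstCase = Math.ulp(d) * ((1 << 20) - 1);
System.out.println(worstCase / d);   // ~1.2e-10, comfortably under 2^-32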

Depending on how the distances are distributed, we can make a probabilistic argument about how much the ordering will change as we lose precision in the mantissa; specifically, suppose we’ve got two distances a and b and we lose 24 bits:

A = S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM
B = S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM
                  |-----------------------------| |------------------------|
                        keeping these bits               losing these

The probability of changing the ordering between these two points is the same as the probability that the bits we're keeping are all identical between them. The digit distribution described by Benford's law converges rapidly to uniform for later digits, and the leading 1 is implied, so in practical terms P(reordering) is very nearly 2^-k, where k is the number of bits being kept.
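
If you want to convince yourself of that empirically, here's a quick Monte Carlo sketch under one specific assumption – distances drawn uniformly from [1, 2), so the sign and exponent bits always match and only the kept mantissa bits are in play:

// Estimate how often two independent values collide in their kept bits.
// With 12 kept mantissa bits the observed rate should come out close to 2^-12.
java.util.Random rng = new java.util.Random(42);
int keptMantissaBits = 12;
long keepMask = -1L << (52 - keptMantissaBits);   // sign + exponent + top 12 mantissa bits
int trials = 10_000_000, collisions = 0;
for (int t = 0; t < trials; ++t) {
  double a = 1.0 + rng.nextDouble(), b = 1.0 + rng.nextDouble();
  if ((Double.doubleToLongBits(a) & keepMask) == (Double.doubleToLongBits(b) & keepMask))
    ++collisions;
}
System.out.printf("observed %.6f, predicted %.6f%n",
                  (double) collisions / trials, Math.pow(2, -keptMantissaBits));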

If we can lose some precision without causing problems (which for my use case was true), then we can arbitrarily reassign low-order mantissa bits to store information. In this case I’m storing the original array index for each distance in its low-order bits. Here’s the code above, piece by piece:

// The largest index is distances.length - 1; when we encode stuff, we need to
// reserve enough bits to represent this quantity.
int largest = distances.length - 1;

// Now construct a mask that selects the bits we're keeping in the double. -1
// has all bits set, and each left-shift clears one from the right. The test
// "(~mask & largest) != largest" will fail and end the loop when the bits we're
// keeping are disjoint from the ones we'll use to store the indexes.
long mask = -1L;
while ((~mask & largest) != largest) mask <<= 1;
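
As a concrete example of what that produces (this snippet isn't from the original code; the element count is made up): with a million elements, largest = 999,999 needs 20 bits, so the loop shifts 20 times and leaves the low 20 bits clear.

int largest = 1_000_000 - 1;
long mask = -1L;
while ((~mask & largest) != largest) mask <<= 1;
System.out.println(Long.toHexString(mask));   // fffffffffff00000 -- top 44 bits kept
System.out.println(Long.bitCount(~mask));     // 20 bits reserved for indexes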

Then we tag each distance this way (here I'm assuming we've got between 2^19 and 2^20 elements, so we reserve 20 bits):

  d[i]  S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMMMMM

        |------------------original bits-----------------||-----tag space------|
& mask  1 11111111111 1111 11111111 11111111 11111111 11110000 00000000 00000000
| i     0 00000000000 0000 00000000 00000000 00000000 0000IIII IIIIIIII IIIIIIII

=       S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMIIII IIIIIIII IIIIIIII

// Store the original index into each double using the strategy described above
for (int i = 0; i < distances.length; ++i)
  distances[i] = Double.longBitsToDouble(
                   Double.doubleToLongBits(distances[i]) & mask | i);

At this point Arrays.sort() will be none the wiser and will sort the array normally (and, importantly, very quickly).

Now we can read the tags back to recover the ordering, which looks like this:

  mask  1 11111111111 1111 11111111 11111111 11111111 11110000 00000000 00000000
 ~mask  0 00000000000 0000 00000000 00000000 00000000 00001111 11111111 11111111

  d[i]  S EEEEEEEEEEE MMMM MMMMMMMM MMMMMMMM MMMMMMMM MMMMIIII IIIIIIII IIIIIIII
& ~mask 0 00000000000 0000 00000000 00000000 00000000 0000IIII IIIIIIII IIIIIIII
(int)                                         00000000 0000IIII IIIIIIII IIIIIIII

// Do the mask inversion up front rather than inside each iteration; now it
// selects the tag bits.
mask = ~mask;

// Collect the elements using the original positions as an indirect index.
List<Element> sorted = new ArrayList<>(elements.length);
for (double d : distances)
  sorted.add(elements[(int) (Double.doubleToLongBits(d) & mask)]);
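
Putting the pieces together, here's a small self-contained demo of the whole trick. The class name, the use of Strings as stand-ins for Element objects, and the sample values are all mine, not from the original code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class IndirectSortDemo {
  public static void main(String[] args) {
    String[] elements  = { "a", "b", "c" };          // stand-ins for Element objects
    double[] distances = { 3.25, 0.5, 1.75 };

    // Reserve enough low bits for the largest index, tag each distance with its
    // index, sort, then read the indexes back out of the sorted doubles.
    int largest = distances.length - 1;
    long mask = -1L;
    while ((~mask & largest) != largest) mask <<= 1;
    for (int i = 0; i < distances.length; ++i)
      distances[i] = Double.longBitsToDouble(
                       Double.doubleToLongBits(distances[i]) & mask | i);
    Arrays.sort(distances);

    List<String> sorted = new ArrayList<>(elements.length);
    for (double d : distances)
      sorted.add(elements[(int) (Double.doubleToLongBits(d) & ~mask)]);

    System.out.println(sorted);                      // prints [b, c, a]
  }
}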

Other languages

You can use this hack in any language to similar effect, though you lose most of the performance advantages if the language doesn’t have bitwise access to doubles. Even a well-optimized sorting function that doesn’t have the Java indirection problem will benefit if you can store the data in the floats directly, since the array will be smaller in memory and you’re doing O(n log n) element-copy operations.

It is possible to simulate bitwise access using floating point arithmetic and int casting (which flotsam does in JavaScript), but it requires some care – particularly in this case, where any rounding error will cause data loss within the indexes. It's also a lot slower than the bitwise solution above, possibly enough to outweigh any performance benefits in the sorting logic itself.

When this kind of thing doesn’t work

Distances are ideal for this type of hack because they tend to be spread over a wide range of magnitudes, and even when they aren't, you don't tend to care much whether the points are exactly ordered (i.e. the parts-per-trillion error we're introducing doesn't really pose a problem most of the time). Not all distributions are so robust to bit-twiddling, though. In particular, if the spread of the values were many orders of magnitude smaller than their average – e.g. 1,000,000,000 ±0.003 – then there's a good chance the low-order bits would matter. It's important to work out the probability of a false reordering before deciding how many bits of precision to give up.
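
To see why that example fails, assume distances of about 10^9 tagged with 20 index bits as in the earlier example: one ulp at that magnitude is roughly 1.2e-7, so the tag can move a value by up to about 0.125, which dwarfs the ±0.003 spread and scrambles the ordering.

double magnitude = 1.0e9;
double worstCaseTagError = Math.ulp(magnitude) * ((1 << 20) - 1);
System.out.println(worstCaseTagError);   // ~0.125, vs. a spread of only 0.003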
