Counting inversions in O(n log n) is a legendary algorithmic problem with an infamous divide-and-conquer solution. But did you know that there is a whole family of algorithms that solve it with the same asymptotics? Let's find out!
Inversions
Let a be our array and i, j indices.
By definition, the number of inversions is the number of pairs (i, j) for which the following predicate holds: a[i] > a[j] and i < j.
A simple quadratic algorithm is easy to construct:
# O(n ^ 2)
def inversions_base(a):
    n, cnt = len(a), 0
    for i in range(0, n):
        for j in range(0, n):
            if a[i] > a[j] and i < j:
                cnt += 1
    return cnt
How to progress from it? First we need to reformulate the counting task in terms of a flag array.
# more complicated, but still O(n ^ 2)
def inversions_flag_array(a):
    n, cnt, M = len(a), 0, max(a)
    T = [0] * (M + 1)
    for i in range(0, n):
        T[a[i]] += 1  # note, exactly +1
        for j in range(a[i] + 1, M + 1):
            cnt += T[j]
    return cnt
What is going on here? We iterate over the array elements one by one, in order, and maintain an auxiliary array T of flags. Whenever we meet k = a[i], we update T[k] += 1 (for arrays with distinct elements T[k] = 1 would be enough). Then, by looking at T, we can easily count the occurrences of elements that are larger than a[i] but located before it. Since T is an ordinary array, we simply iterate from a[i] + 1 to max(a). This still results in a quadratic solution, but we can make improvements from here.
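To make the bookkeeping concrete, here is a short trace of this idea on a sample array (the values are mine, not from the article):

```python
# Trace the flag-array counting on a = [3, 1, 2],
# which has exactly two inversions: (3, 1) and (3, 2).
a = [3, 1, 2]
M = max(a)
T = [0] * (M + 1)
cnt = 0
for x in a:
    T[x] += 1
    # elements larger than x seen so far each form an inversion with x
    cnt += sum(T[j] for j in range(x + 1, M + 1))
    print(x, T, cnt)
print(cnt)  # 2
```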
Note that the inner part of the main loop is a sum query. And as we know (don't we?), such queries over an array interval can be made more efficient with the help of data structures. A Fenwick tree is certainly the way to go here. Let's do it, starting with a class for our tree.
One can actually substitute any appropriate data structure for the Fenwick tree here, for example, a segment tree.
Fenwick tree
First, let's discuss the Fenwick tree a little bit.
Let's start with a practical definition. From this point on, a Fenwick tree is just an array T which stores specific interval sums of the given array, namely:
T[i] = a[f(i)] + … + a[i], where f(i) = i AND (i + 1).
It can be built directly from this definition in O(n ^ 2) (though it's easy to do in O(n log n) and even in O(n)). The most crucial part here is how to compute any prefix sum efficiently using it.
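As an aside, the O(n) build mentioned above can be sketched like this (my own sketch, not from the article): start from a copy of a, then push each finished cell t[i] up to the next node i | (i + 1) whose interval covers it:

```python
def fenwick_build(a):
    # O(n) construction: t[i] ends up equal to a[f(i)] + ... + a[i]
    t = a[:]
    n = len(a)
    for i in range(n):
        j = i | (i + 1)  # the next node whose interval covers index i
        if j < n:
            t[j] += t[i]
    return t

print(fenwick_build([1, 2, 3, 4]))  # [1, 3, 3, 10]
```

For [1, 2, 3, 4] this gives T[0] = a[0], T[1] = a[0] + a[1], T[2] = a[2], T[3] = a[0] + … + a[3], matching the definition above.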
To see why it is actually a tree, let's look at the picture:
It is called an interrogation tree and it basically gives us a hint on how to compute a prefix sum from the tree nodes. We start from our index, then clear the trailing 1-bits and make one step back from this number to get into another precomputed interval. In other words, the parent of a child node i in this tree is just the number f(i) - 1.
For example, if we are at k, we compute f(k), so a[f(k)] + … + a[k] is in T[k] by the definition of T. Then we can proceed with f(k) - 1 to get the lower part of the sum, a[f(f(k) - 1)] + … + a[f(k) - 1], and so on, until we finally move below the zero index. Obviously, f(k) - 1 < k, so the algorithm will eventually stop, since we started it from a finite number k.
And this actually happens in a logarithmic number of steps, resulting in O(log n) complexity for the sum computation. Here is the argument: if x ends with p ones, then the lowest set bit of x + 1 is 2^p, and clearing those trailing ones gives x AND (x + 1) = (x + 1) - 2^p. So each step x → (x AND (x + 1)) - 1 removes the lowest set bit from x + 1, and the total number of steps equals the number of 1-bits in the binary representation of x + 1. Run this code to check:
nums = list(range(0, 21))
steps = []
for num in nums:
    cnt, step = num, 0
    while cnt >= 0:  # mirror the loop in prefix_sum(), which runs while i >= 0
        step += 1
        cnt = (cnt & (cnt + 1)) - 1
    steps.append(step)
print(steps)  # that's OEIS A000120 with a shift: steps[n] = popcount(n + 1)
Let's look at an example. Suppose you want to find the prefix sum of the first 15 elements. Then the following indexes will be visited in prefix_sum() (also, look at the interrogation tree above):
- 14 (01110);
- 14 AND (14 + 1) - 1 = 13 (01101);
- 13 AND (13 + 1) - 1 = 11 (01011);
- 11 AND (11 + 1) - 1 = 7 (00111);
So a[0] + … + a[14] = T[7] + T[11] + T[13] + T[14].
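This decomposition is easy to check numerically straight from the definition T[i] = a[f(i)] + … + a[i] (the sample values here are arbitrary):

```python
a = list(range(1, 16))  # any 15 values will do
# build T directly from its definition T[i] = a[f(i)] + ... + a[i]
T = [sum(a[(i & (i + 1)):i + 1]) for i in range(15)]
print(T[7], T[11], T[13], T[14], sum(a))  # 36 42 27 15 120
assert sum(a) == T[7] + T[11] + T[13] + T[14]
```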
OK, now let's find out how to modify the nodes of the tree if we've just changed one element of our array by adding some number d. We should change only the intervals which include a[i], that is, the T[j] with f(j) <= i <= j. What are these j?
Firstly, the sequence should start from j = i; then it can be proved that the next element of the sequence is given by the expression j | (j + 1).
Without going too deep into details, this transformation sets the lowest zero bit, adding one more 1 to the binary representation of the number. Doing this until we get an index larger than size(T) results in O(log(size(T))) complexity.
Let's build a table where each column shows which array indexes are stored at each tree index.
Let's look at an example. Suppose you want to modify a[8]. Then the following indexes will be affected by modify():
- 8 (01000);
- 8 OR (8 + 1) = 9 (01001);
- 9 OR (9 + 1) = 11 (01011);
- 11 OR (11 + 1) = 15 (01111);
These indexes coincide with the intersection shown in the table above.
Finally, here’s the code:
class FenwickTree:
    def __init__(self, n):
        self.t = [0] * n  # long empty list
        self.N = n

    def modify(self, i, d):
        while i < self.N:
            self.t[i] += d
            i = i | (i + 1)

    def prefix_sum(self, i):
        result = 0
        while i >= 0:
            result += self.t[i]
            i = (i & (i + 1)) - 1
        return result

    def query(self, i, j):
        return self.prefix_sum(j) - self.prefix_sum(i - 1)
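A quick sanity check of the tree against plain slicing sums (the class is repeated here so the snippet runs on its own; the sample values are mine):

```python
class FenwickTree:  # same class as above, repeated to keep the snippet standalone
    def __init__(self, n):
        self.t = [0] * n
        self.N = n

    def modify(self, i, d):
        while i < self.N:
            self.t[i] += d
            i = i | (i + 1)

    def prefix_sum(self, i):
        result = 0
        while i >= 0:
            result += self.t[i]
            i = (i & (i + 1)) - 1
        return result

    def query(self, i, j):
        return self.prefix_sum(j) - self.prefix_sum(i - 1)

a = [5, 1, 4, 2, 8, 0, 3, 7, 6, 9]
ft = FenwickTree(len(a))
for i, x in enumerate(a):
    ft.modify(i, x)
print(ft.prefix_sum(4) == sum(a[:5]))  # True
print(ft.query(3, 7) == sum(a[3:8]))   # True
```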
Back to flag arrays
Now the flag-array code snippet transforms into:
# finally, O(n log n)
def inversions_fenwick(a):
    n, cnt, M = len(a), 0, max(a)
    T = FenwickTree(M + 1)
    for i in range(0, n):
        T.modify(a[i], 1)
        cnt += T.query(a[i] + 1, M)
    return cnt
Last but not least: do not forget to compress the coordinates of the array (replace the elements by their ranks in sorted order, in O(n log n)) to keep the tree small and the solution efficient. And that's it. This simple idea caught my attention recently, since it is not as well known as the classic mergesort-like algorithm.
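A minimal compression sketch (the helper name is mine): map each value to its rank among the sorted distinct values, which preserves the relative order and hence the inversion count, while the Fenwick tree only needs max rank + 1 cells:

```python
def compress(a):
    # replace each value by its index in the sorted list of distinct values
    rank = {v: r for r, v in enumerate(sorted(set(a)))}
    return [rank[v] for v in a]

print(compress([10**9, 7, 42, 7]))  # [2, 0, 1, 0]
```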
Instead of conclusion
In this small article we investigated a nice solution to the problem of counting inversions in an array. You can find the full code here.