Skip to content

Commit

Permalink
Add radix heap
Browse files Browse the repository at this point in the history
  • Loading branch information
indy256 committed Aug 30, 2022
1 parent 4af80c8 commit 6cbe563
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 88 deletions.
2 changes: 1 addition & 1 deletion cpp/graphs/flows/min_cost_flow_dijkstra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct min_cost_flow {
if (dist[v] == numeric_limits<int>::max())
h.add(v, nprio);
else
h.changePriority(v, nprio);
h.change_value(v, nprio);
dist[v] = nprio;
prevnode[v] = u;
prevedge[v] = i;
Expand Down
4 changes: 2 additions & 2 deletions cpp/graphs/shortestpaths/dijkstra_custom_heap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void add(int id, int prio) {
move_up(hsize++);
}

void increase_priority(int id, int prio) {
void decrease_value(int id, int prio) {
int pos = id2Pos[id];
h[pos] = prio;
move_up(pos);
Expand Down Expand Up @@ -101,7 +101,7 @@ void dijkstra(int s) {
if (prio[v] == numeric_limits<int>::max())
add(v, nprio);
else
increase_priority(v, nprio);
decrease_value(v, nprio);
prio[v] = nprio;
pred[v] = u;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/structures/binary_heap_indexed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ int main() {
h.add(0, 50);
h.add(1, 30);
h.add(2, 40);
h.changePriority(0, 20);
h.change_value(0, 20);
h.remove(1);
while (h.size) {
cout << h.remove_min() << endl;
Expand Down
4 changes: 2 additions & 2 deletions cpp/structures/binary_heap_indexed.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ struct binary_heap_indexed {
int pos = id2Pos[id];
pos2Id[pos] = pos2Id[--size];
id2Pos[pos2Id[pos]] = pos;
changePriority(pos2Id[pos], heap[size]);
change_value(pos2Id[pos], heap[size]);
}

void changePriority(int id, T value) {
void change_value(int id, T value) {
int pos = id2Pos[id];
if (heap[pos] < value) {
heap[pos] = value;
Expand Down
4 changes: 3 additions & 1 deletion java/graphs/flows/MinCostFlowDijkstra.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.*;
import java.util.stream.Stream;
import structures.BinaryHeapIndexed;
import structures.RadixHeapIndexed;

// https://cp-algorithms.com/graph/min_cost_flow.html in O(E * V + min(E * logV * FLOW, V^2 * FLOW))
// negative-cost edges are allowed
Expand Down Expand Up @@ -61,6 +62,7 @@ void bellmanFord(int s, int[] dist) {
void dijkstraSparse(
int s, int t, int[] pot, int[] dist, boolean[] finished, int[] curflow, int[] prevnode, int[] prevedge) {
BinaryHeapIndexed h = new BinaryHeapIndexed(graph.length);
// RadixHeapIndexed h = new RadixHeapIndexed(graph.length);
h.add(s, 0);
Arrays.fill(dist, Integer.MAX_VALUE);
dist[s] = 0;
Expand All @@ -79,7 +81,7 @@ void dijkstraSparse(
if (dist[v] == Integer.MAX_VALUE)
h.add(v, nprio);
else
h.changePriority(v, nprio);
h.changeValue(v, nprio);
dist[v] = nprio;
prevnode[v] = u;
prevedge[v] = i;
Expand Down
86 changes: 9 additions & 77 deletions java/graphs/shortestpaths/DijkstraCustomHeap.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@

import java.util.*;
import java.util.stream.Stream;
import structures.BinaryHeapIndexed;

// https://en.wikipedia.org/wiki/Dijkstra's_algorithm
public class DijkstraCustomHeap {
// calculate shortest paths in O(E*log(V)) time and O(V) memory
public static void shortestPaths(List<Edge>[] graph, int s, long[] prio, int[] pred) {
public static void shortestPaths(List<Edge>[] graph, int s, int[] prio, int[] pred) {
Arrays.fill(pred, -1);
Arrays.fill(prio, Long.MAX_VALUE);
Arrays.fill(prio, Integer.MAX_VALUE);
prio[s] = 0;
BinaryHeap h = new BinaryHeap(graph.length);
BinaryHeapIndexed h = new BinaryHeapIndexed(graph.length);
h.add(s, 0);
while (h.size != 0) {
int u = h.remove();
int u = h.removeMin();
for (Edge e : graph[u]) {
int v = e.t;
long nprio = prio[u] + e.cost;
int nprio = prio[u] + e.cost;
if (prio[v] > nprio) {
if (prio[v] == Long.MAX_VALUE)
if (prio[v] == Integer.MAX_VALUE)
h.add(v, nprio);
else
h.increasePriority(v, nprio);
h.changeValue(v, nprio);
prio[v] = nprio;
pred[v] = u;
}
Expand All @@ -39,75 +40,6 @@ public Edge(int t, int cost) {
}
}

static class BinaryHeap {
long[] heap;
int[] pos2Id;
int[] id2Pos;
int size;

public BinaryHeap(int n) {
heap = new long[n];
pos2Id = new int[n];
id2Pos = new int[n];
}

public int remove() {
int removedId = pos2Id[0];
heap[0] = heap[--size];
pos2Id[0] = pos2Id[size];
id2Pos[pos2Id[0]] = 0;
down(0);
return removedId;
}

public void add(int id, long value) {
heap[size] = value;
pos2Id[size] = id;
id2Pos[id] = size;
up(size++);
}

public void increasePriority(int id, long value) {
heap[id2Pos[id]] = value;
up(id2Pos[id]);
}

void up(int pos) {
while (pos > 0) {
int parent = (pos - 1) / 2;
if (heap[pos] >= heap[parent])
break;
swap(pos, parent);
pos = parent;
}
}

void down(int pos) {
while (true) {
int child = 2 * pos + 1;
if (child >= size)
break;
if (child + 1 < size && heap[child + 1] < heap[child])
++child;
if (heap[pos] <= heap[child])
break;
swap(pos, child);
pos = child;
}
}

void swap(int i, int j) {
long tt = heap[i];
heap[i] = heap[j];
heap[j] = tt;
int t = pos2Id[i];
pos2Id[i] = pos2Id[j];
pos2Id[j] = t;
id2Pos[pos2Id[i]] = i;
id2Pos[pos2Id[j]] = j;
}
}

// Usage example
public static void main(String[] args) {
int[][] cost = {{0, 3, 2}, {0, 0, -2}, {0, 0, 0}};
Expand All @@ -120,7 +52,7 @@ public static void main(String[] args) {
}
}
}
long[] dist = new long[n];
int[] dist = new int[n];
int[] pred = new int[n];
shortestPaths(edges, 0, dist, pred);
System.out.println(0 == dist[0]);
Expand Down
8 changes: 4 additions & 4 deletions java/structures/BinaryHeapIndexed.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ public void remove(int id) {
int pos = id2Pos[id];
pos2Id[pos] = pos2Id[--size];
id2Pos[pos2Id[pos]] = pos;
changePriority(pos2Id[pos], heap[size]);
changeValue(pos2Id[pos], heap[size]);
}

public void changePriority(int id, int value) {
public void changeValue(int id, int value) {
int pos = id2Pos[id];
if (heap[pos] < value) {
heap[pos] = value;
Expand Down Expand Up @@ -90,8 +90,8 @@ public static void main(String[] args) {
heap.add(1, 5);
heap.add(2, 2);

heap.changePriority(1, 3);
heap.changePriority(2, 6);
heap.changeValue(1, 3);
heap.changeValue(2, 6);
heap.remove(0);

// print elements in sorted order
Expand Down
90 changes: 90 additions & 0 deletions java/structures/RadixHeapIndexed.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package structures;

import java.util.Arrays;

public class RadixHeapIndexed {
static final int BUCKETS = 32 + 1;
int[][] values = new int[BUCKETS][1];
int[][] ids = new int[BUCKETS][1];
int[] len = new int[BUCKETS];
int[] minValues;
int last;
int ptr;
public int size;

public RadixHeapIndexed(int n) {
this.minValues = new int[n];
Arrays.fill(minValues, Integer.MAX_VALUE);
}

static int lg(int x) {
return 32 - Integer.numberOfLeadingZeros(x);
}

public void add(int id, int value) {
addItem(id, value);
++size;
}

void addItem(int id, int value) {
int bucket = lg(value ^ last);
ensureCapacity(bucket);
values[bucket][len[bucket]] = value;
ids[bucket][len[bucket]] = id;
++len[bucket];
minValues[id] = value;
}

void ensureCapacity(int bucket) {
if (values[bucket].length == len[bucket]) {
values[bucket] = Arrays.copyOf(values[bucket], len[bucket] * 2);
ids[bucket] = Arrays.copyOf(ids[bucket], len[bucket] * 2);
}
}

public void changeValue(int id, int value) {
addItem(id, value);
}

public int removeMin() {
pull();
--size;
return ids[0][ptr++];
}

void pull() {
if (ptr < len[0])
return;
len[0] = 0;
ptr = 0;
int i = 1;
last = Integer.MAX_VALUE;
do {
while (len[i] == 0) ++i;
for (int j = 0; j < len[i]; j++) {
if (values[i][j] == minValues[ids[i][j]]) {
last = Math.min(last, values[i][j]);
}
}
for (int j = 0; j < len[i]; j++) {
if (values[i][j] == minValues[ids[i][j]]) {
addItem(ids[i][j], values[i][j]);
}
}
len[i] = 0;
} while (last == Integer.MAX_VALUE);
}

public static void main(String[] args) {
RadixHeapIndexed h = new RadixHeapIndexed(10);
h.add(0, 10);
h.add(1, 5);
System.out.println(h.removeMin());
h.add(2, 9);
System.out.println(h.removeMin());
h.add(3, 10);
h.changeValue(0, 10);
System.out.println(h.removeMin());
System.out.println(h.removeMin());
}
}

0 comments on commit 6cbe563

Please sign in to comment.