-
Notifications
You must be signed in to change notification settings - Fork 0
/
treap.cpp
100 lines (89 loc) · 2.12 KB
/
treap.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/* Treap implementation
* supports find/insert/erase as well as split/merge
*/
#include <cstdlib>
#include <utility>
using namespace std;
/*
 * Treap: a binary search tree ordered by value that is simultaneously
 * heap-ordered by a per-node priority (a larger priority sits closer to the
 * root). Priorities come from rand(), and the container behaves as a set:
 * inserting a value that is already stored leaves the tree unchanged.
 *
 * T must be default-constructible (for the sentinel), copyable, and support
 * operator<, operator> and operator<=.
 *
 * Instead of null pointers, empty subtrees are represented by a single
 * sentinel node `bottom` whose children point back to itself.
 */
template<typename T>
struct Treap
{
    struct Node
    {
        // Sentinel constructor: both children refer to the node itself.
        Node() : val(), prior(0), left(this), right(this) {}
        // Regular node holding value v with priority p and children l / r.
        Node(T v, int p, Node *l, Node *r)
            : val(std::move(v)), prior(p), left(l), right(r) {}
        T val;
        int prior;
        Node *left, *right;
    };

    Node *bottom = new Node(); // shared "empty subtree" sentinel
    Node *root = bottom;       // root of the tree; == bottom when empty

    // Re-seeds the global C random generator with a fixed seed, so the
    // sequence of priorities is reproducible. Note that this affects every
    // other user of rand() in the program.
    Treap() { std::srand(42); }

    // Frees every node reachable from root, then the sentinel.
    ~Treap() { destroy(root); delete bottom; }

    // The treap owns its nodes through raw pointers; a member-wise copy would
    // make two objects free the same nodes, so copying is disabled.
    Treap(const Treap &) = delete;
    Treap &operator=(const Treap &) = delete;

    // Recursively deletes the subtree rooted at r (the sentinel is skipped).
    void destroy(Node *r)
    {
        if(r == bottom) return;
        destroy(r->left);
        destroy(r->right);
        delete r;
    }

    // Splits the subtree rooted at r into two treaps and returns their roots:
    // first holds all values < key, second holds all values >= key.
    // The nodes are relinked in place; r must not be used as a whole
    // afterwards.
    std::pair<Node *, Node *> split(Node *r, T key)
    {
        if(r == bottom) return std::make_pair(bottom, bottom);
        if(key <= r->val)
        {
            // r and its right subtree belong to the ">= key" part.
            auto t = split(r->left, key);
            r->left = t.second;
            return std::make_pair(t.first, r);
        }
        else
        {
            // r and its left subtree belong to the "< key" part.
            auto t = split(r->right, key);
            r->right = t.first;
            return std::make_pair(r, t.second);
        }
    }

    // Joins two treaps into one and returns the new root. The caller must
    // guarantee that every value in a is ordered before every value in b;
    // this is not checked here.
    Node *merge(Node *a, Node *b)
    {
        if(a == bottom) return b;
        if(b == bottom) return a;
        if(a->prior > b->prior)
        {
            a->right = merge(a->right, b);
            return a;
        }
        else
        {
            b->left = merge(a, b->left);
            return b;
        }
    }

    // Returns true when val is stored in the treap.
    bool find(T val) { return find(val, root); }

    // Inserts val with a random priority; does nothing if val is present.
    void insert(T val) { root = insert(val, std::rand(), root); }

    // Removes val if present; does nothing otherwise.
    void erase(T val) { root = erase(val, root); }

    // Returns true when val is stored in the subtree rooted at r.
    bool find(T val, Node *r)
    {
        if(r == bottom) return false;
        if(val < r->val) return find(val, r->left);
        if(val > r->val) return find(val, r->right);
        return true;
    }

    // Inserts val with priority prior into the subtree rooted at r and
    // returns the (possibly new) root of that subtree.
    Node *insert(T val, int prior, Node *r)
    {
        if(r == bottom) return new Node(val, prior, bottom, bottom);
        if(prior > r->prior)
        {
            // The new node would become the root of this subtree. Make sure
            // val is not already somewhere below, otherwise the split would
            // push the existing node into the right part and the value would
            // end up stored twice.
            if(find(val, r)) return r;
            auto t = split(r, val);
            return new Node(val, prior, t.first, t.second);
        }
        if(val < r->val) r->left = insert(val, prior, r->left);
        else if(val > r->val) r->right = insert(val, prior, r->right);
        // val == r->val: already present, keep the subtree as is.
        return r;
    }

    // Removes val from the subtree rooted at r and returns the new root of
    // that subtree. The removed node is deallocated.
    Node *erase(T val, Node *r)
    {
        if(r == bottom) return bottom;
        if(val < r->val) r->left = erase(val, r->left);
        else if(val > r->val) r->right = erase(val, r->right);
        else
        {
            // Replace the node by the merge of its children, then free it.
            Node *merged = merge(r->left, r->right);
            delete r;
            return merged;
        }
        return r;
    }
};