-
Notifications
You must be signed in to change notification settings - Fork 0
/
avl.cpp
115 lines (101 loc) · 2.34 KB
/
avl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/* AVL tree implementation.
 * Supports the classic set operations find/insert/erase; inserting a value
 * that is already present has no effect.
 */
#include <algorithm>
using namespace std;
template<typename T>
struct AVLTree
{
    // A tree node. Empty child links point at the shared sentinel `bottom`
    // (height 0) instead of nullptr, so height/balance need no null checks.
    struct Node
    {
        // Sentinel constructor: height 0, both child links point to itself.
        Node() { height = 0; left = right = this; }
        // Regular node constructor: a newly created node has height 1.
        Node(T v, Node *l, Node *r) { height = 1; val = v; left = l; right = r; }
        // Recompute height from the (already up-to-date) heights of the children.
        void recalc() { height = std::max(left->height, right->height) + 1; }
        // Balance factor: right height minus left height (> 0 means right-heavy).
        int balance() { return right->height - left->height; }
        T val;
        int height;
        Node *left, *right;
    };

    // Sentinel shared by every empty link; freed in the destructor.
    Node *bottom = new Node();
    // `deleted` and `last` are scratch pointers used only by erase();
    // `root` is the tree root (== bottom when the tree is empty).
    Node *deleted = bottom, *last = bottom, *root = bottom;

    // NOTE(review): there is no user-defined copy constructor/assignment, so a
    // copy would share the same node pointers and both destructors would free
    // them. Do not copy an AVLTree.
    ~AVLTree() { destroy(root); delete bottom; }

    // Free every node of the subtree rooted at r (children first).
    // The sentinel itself is never freed here.
    void destroy(Node *r)
    {
        if(r == bottom) return;
        destroy(r->left);
        destroy(r->right);
        delete r;
    }

    // Rotate left around r; returns the new subtree root (r's former right child).
    Node *rotateleft(Node *r)
    {
        Node *t = r->right;
        r->right = t->left;
        t->left = r;
        r->recalc();
        t->recalc();
        return t;
    }

    // Rotate right around r; returns the new subtree root (r's former left child).
    Node *rotateright(Node *r)
    {
        Node *t = r->left;
        r->left = t->right;
        t->right = r;
        r->recalc();
        t->recalc();
        return t;
    }

    // Recompute r's height and, if its balance factor is outside [-1, 1],
    // restore it with a single or double rotation.
    // Returns the (possibly new) root of this subtree.
    Node *rebalance(Node *r)
    {
        r->recalc();
        if(r->balance() > 1)
        {
            // Right-left case: first rotate the right child to the right.
            if(r->right->balance() < 0)
                r->right = rotateright(r->right);
            r = rotateleft(r);
        }
        else if(r->balance() < -1)
        {
            // Left-right case: first rotate the left child to the left.
            if(r->left->balance() > 0)
                r->left = rotateleft(r->left);
            r = rotateright(r);
        }
        return r;
    }

    // Public set interface.
    bool find(T val) { return find(val, root); }
    void insert(T val) { root = insert(val, root); }
    void erase(T val)
    {
        // Start each erase with clean scratch state so nothing from an earlier
        // (e.g. unsuccessful) erase can be matched in this one.
        deleted = bottom;
        root = erase(val, root);
    }

    // Return true if val occurs in the subtree rooted at r.
    bool find(T val, Node *r)
    {
        if(r == bottom) return false;
        if(val < r->val) return find(val, r->left);
        if(val > r->val) return find(val, r->right);
        return true;
    }

    // Insert val into the subtree rooted at r and return the new subtree root.
    // A value equal to an existing one is not inserted again.
    Node *insert(T val, Node *r)
    {
        if(r == bottom) return new Node(val, bottom, bottom);
        if(val < r->val) r->left = insert(val, r->left);
        if(val > r->val) r->right = insert(val, r->right);
        return rebalance(r);
    }

    // Remove val (if present) from the subtree rooted at r.
    // Returns the new subtree root.
    //
    // The search goes left when val < node value and right otherwise. It
    // remembers in `deleted` the last node where it went right; that is the
    // only candidate that can hold val. `last` ends up as the deepest node on
    // the path. If the candidate holds val, it takes over last's value and
    // `last` is unlinked from the tree.
    Node *erase(T val, Node *r)
    {
        if(r == bottom) return bottom;
        last = r;
        if(val < r->val) r->left = erase(val, r->left);
        else deleted = r, r->right = erase(val, r->right);
        if(r == last)
        {
            // r is the deepest node on the search path: its child in the
            // direction the search took is the sentinel.
            if(deleted != bottom && deleted->val == val)
            {
                deleted->val = last->val;
                deleted = bottom;
                // At least one child of `last` is the sentinel (the one the
                // search followed), so splice in the other child.
                // If the search went right at `last` (last == deleted), its
                // left child may be a real leaf that must be kept.
                Node *t = last->left != bottom ? last->left : last->right;
                delete last;
                // Do not leave a pointer to freed memory for the callers'
                // `r == last` comparisons.
                last = bottom;
                return t;
            }
            return r;  // val not present: nothing below r changed
        }
        return rebalance(r);
    }
};