Basic Principles
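A splay tree is a kind of self-adjusting binary search tree. It preserves the ordering property of a BST, but keeps itself balanced in a different way: every accessed node is rotated up to the root ("splayed"). Operations take O(log n) amortized time, although a single operation can cost O(n) in the worst case, when the tree has degenerated into a chain. Splay trees are much simpler to implement than red-black trees, and in practice they often have better constant factors than red-black trees, which can be rather slow.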
Sample Codes
The following sample code implements a basic splay tree ordered by key. It contains only the splay operation (plus insertion and predecessor/successor helpers); no query operations are implemented.
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int maxn = 1010;
// Key-ordered splay tree over a static pool of maxn nodes.
// Node ids run from 1 to ncnt; id 0 is the null sentinel, so a parent or
// child value of 0 means "none".  Only insertion, splaying and in-subtree
// predecessor/successor lookup are provided (no query operations).
class SplayTree
{
public:
    // Structural data: parent links, children, current root, nodes allocated.
    int parent[maxn], ch[maxn][2], root, ncnt;
    // Key stored in each node.
    int val[maxn];
    #define lc(x) ch[x][0]
    #define rc(x) ch[x][1]
    #define par(x) parent[x]
    // Rotate p above its parent q; the in-order sequence is unchanged.
    void rotate(int p)
    {
        int q = par(p), y = par(q), x = (rc(q) == p);
        ch[q][x] = ch[p][!x];
        // Skip the null sentinel so that parent[0] always stays 0.
        if (ch[q][x]) par(ch[q][x]) = q;
        ch[p][!x] = q;
        par(q) = p;
        par(p) = y;
        if (y) {
            if (lc(y) == q)
                lc(y) = p;
            else
                rc(y) = p;
        }
    }
    // Move p to the root using zig / zig-zig / zig-zag steps.
    void splay(int p)
    {
        for (int q = 0; (q = par(p)) != 0; rotate(p))
            if (par(q))
                // Same-side chain (zig-zig): rotate the parent first;
                // otherwise (zig-zag) rotate p twice.
                rotate((p == lc(q)) == (q == lc(par(q))) ? q : p);
        root = p;
    }
    // Insert key v as a new leaf below the subtree rooted at p (p must be a
    // real node), then splay it to the root.  Equal keys go to the right.
    void insert(int p, int v)
    {
        int q = 0;
        while (true) {
            if (v < val[p]) q = lc(p);
            else q = rc(p);
            if (q == 0) break;
            p = q;
        }
        q = ++ncnt;
        val[q] = v;
        lc(q) = rc(q) = 0;
        par(q) = p;
        if (v < val[p]) lc(p) = q;
        else rc(p) = q;
        splay(q);
    }
    // Insert key v into the whole tree (creating the root if it is empty).
    void insert(int v)
    {
        if (root == 0) {
            root = ++ncnt;
            val[root] = v;
            lc(root) = rc(root) = par(root) = 0;
        } else {
            insert(root, v);
        }
    }
    // Largest node in p's left subtree, or 0 if p has no left child.
    int pre(int p)
    {
        int q = lc(p);
        while (rc(q))
            q = rc(q);
        return q;
    }
    // Smallest node in p's right subtree, or 0 if p has no right child.
    int suc(int p)
    {
        int q = rc(p);
        while (lc(q))
            q = lc(q);
        return q;
    }
    // Dump every allocated node (previously hard-coded to the first 8).
    void debug(void)
    {
        for (int i = 1; i <= ncnt; i++) {
            std::printf("#%d: par %d lc %d rc %d val %d\n", i, parent[i], ch[i][0], ch[i][1], val[i]);
        }
    }
    SplayTree()
    {
        std::memset(ch, 0, sizeof(ch));
        std::memset(parent, 0, sizeof(parent));
        root = 0;
        // Must be reset explicitly: only static instances are zero-initialised.
        ncnt = 0;
    }
} sp;
This version can be used as a sequence (implicit-key) splay tree, in which insert places a new element dynamically after a given position. Note that splay() in this version uses single rotations only (move-to-root) instead of double rotations; this makes debugging easier but gives up the O(log n) amortized time bound. The aim is a version that works correctly rather than one that is fast but buggy.
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
const int maxn = 10010;
// Implicit-key (sequence) splay tree, debugging edition.
// The in-order sequence of nodes is the stored sequence and size[] holds
// subtree sizes, so nodes are addressed by 1-based position.  splay() uses
// single rotations only.  buildtree() creates two boundary nodes (ids 1 and
// 2) at the first and last positions.  Node id 0 is the null sentinel.
class SplayTree
{
public:
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], sum[maxn];
    #define lc(x) ch[x][0]
    #define rc(x) ch[x][1]
    #define par(x) parent[x]
    // Rotate p above its parent, updating subtree sizes incrementally.
    void rotate(int p)
    {
        int q = par(p), g = par(q), x = p == rc(q);
        size[q] -= size[p];
        size[p] -= size[ch[p][!x]];
        ch[q][x] = ch[p][!x];
        // Keep parent[0] == 0; pre/suc/splay rely on it.
        if (ch[q][x]) par(ch[q][x]) = q;
        ch[p][!x] = q, par(q) = p;
        par(p) = g;
        size[q] += size[ch[q][x]];
        size[p] += size[q];
        if (g) ch[g][rc(g) == q] = p;
    }
    // Move node p to the root by repeated single rotations.
    // Ids that are not allocated nodes are ignored.
    void splay(int p)
    {
        if (p < 1 || p > ncnt) return;
        while (par(p))
            rotate(p);
        root = p;
    }
    // In-order predecessor of node p, or 0 if none (or if p is 0).
    int pre(int p)
    {
        if (!p) return 0;
        if (!lc(p)) {
            while (p == lc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = lc(p);
            while (rc(p)) p = rc(p);
        }
        return p;
    }
    // In-order successor of node p, or 0 if none (or if p is 0).
    int suc(int p)
    {
        if (!p) return 0;
        if (!rc(p)) {
            while (p == rc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = rc(p);
            while (lc(p)) p = lc(p);
        }
        return p;
    }
    // Node at 1-based position x, or 0 if x is outside [1, size[root]].
    int find(int x)
    {
        // Without this check x <= 0 descends into node 0 forever.
        if (x < 1 || x > size[root]) return 0;
        int p = root;
        while (true) {
            if (x <= size[lc(p)]) {
                p = lc(p);
                continue;
            }
            x -= size[lc(p)];
            if (x <= 1)
                return p;
            x -= 1;
            p = rc(p);
        }
        return 0;
    }
    // Allocate a single-node subtree with parent q.
    int makenode(int q, int v)
    {
        int p = ++ncnt;
        n++;
        lc(p) = rc(p) = 0;
        par(p) = q;
        size[p] = 1;
        // sum[p] = v;
        return p;
    }
    // Account for one extra node (of value v) somewhere below p.
    void updatenode(int p, int v)
    {
        size[p]++;
        // sum[p] += v;
    }
    // Insert a new node directly after position x; invalid x is ignored.
    void insert(int x, int v)
    {
        int lp = find(x);
        if (!lp) return;
        int rp = suc(lp);
        splay(lp);
        if (rp) splay(rp);
        // lp is now rp's left child with an empty right subtree.
        int c = makenode(lp, v);
        rc(lp) = c;
        updatenode(lp, v);
        if (rp) updatenode(rp, v);
    }
    // Print parent-child edges for a graph painter; i + 100 stands for a
    // missing sibling so both child slots are always drawn.
    void debug()
    {
        for (int i = 1; i <= n; i++) {
            if (lc(i)) printf("%d %d %d\n", i, lc(i), 1);
            else if (rc(i)) printf("%d %d %d\n", i, i + 100, 1);
            if (rc(i)) printf("%d %d %d\n", i, rc(i), 2);
            else if (lc(i)) printf("%d %d %d\n", i, i + 100, 2);
        }
    }
    // Reset to the two boundary nodes: node 1 (root) with node 2 on its right.
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        rc(root) = makenode(root, 0);
        par(rc(root)) = root;
        updatenode(root, 0);
    }
} sp;
// Interactive driver for the sequence splay tree.  Reads commands until end
// of input: insert <pos> <val>, splay <id>, pre <id>, suc <id>, find <pos>,
// debug.  (Sum queries are not available in this version.)
int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    // Looping on the extraction result stops cleanly at EOF; the original
    // while (true) spun forever once input ran out.
    while (cin >> a)
    {
        int b = 0, c = 0;
        if (a == "insert") {
            cin >> b >> c;
            sp.insert(b, c);
        } else if (a == "splay") {
            cin >> b;
            sp.splay(b);
        } else if (a == "pre") {
            cin >> b;
            printf("pre %d = %d\n", b, sp.pre(b));
        } else if (a == "suc") {
            cin >> b;
            printf("suc %d = %d\n", b, sp.suc(b));
        } else if (a == "find") {
            cin >> b;
            printf("find %d = %d\n", b, sp.find(b));
        } else if (a == "debug") {
            sp.debug();
        }
    }
    return 0;
}
This third version contains several important updates to the splay tree:
- splay() now uses double rotations (zig-zig / zig-zag).
- A node can be splayed to just below a given ancestor (splayto), which makes two chosen nodes adjacent so that everything between them forms a single subtree.
- Range-sum queries are supported, but range addition and other advanced editing operations are not.
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
const int maxn = 10010;
// Implicit-key (sequence) splay tree with subtree sums.
// Nodes are addressed by 1-based in-order position.  buildtree() creates two
// boundary nodes (ids 1 and 2) at the first and last positions, so a range
// [l, r] can be isolated as the right subtree of position l - 1 once
// position r + 1 is the root.  Node id 0 is the null sentinel.
class SplayTree
{
public:
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], sum[maxn];
    #define lc(x) ch[x][0]
    #define rc(x) ch[x][1]
    #define par(x) parent[x]
    // Rotate p above its parent, updating sizes and sums incrementally.
    void rotate(int p)
    {
        int q = par(p), g = par(q), x = p == rc(q);
        size[q] -= size[p];
        size[p] -= size[ch[p][!x]];
        sum[q] -= sum[p];
        sum[p] -= sum[ch[p][!x]];
        ch[q][x] = ch[p][!x];
        // Keep parent[0] == 0 so the null sentinel never looks linked.
        if (ch[q][x]) par(ch[q][x]) = q;
        ch[p][!x] = q, par(q) = p;
        par(p) = g;
        size[q] += size[ch[q][x]];
        size[p] += size[q];
        sum[q] += sum[ch[q][x]];
        sum[p] += sum[q];
        if (g) ch[g][rc(g) == q] = p;
    }
    // Splay p until its parent is t (t == 0 means: up to the root).
    void splayto(int p, int t)
    {
        if (!p) return;
        for (int q = 0; (q = par(p)) && q != t; rotate(p))
            if (par(q) && par(q) != t)
                rotate((p == lc(q)) == (q == lc(par(q))) ? q : p);
        if (t == 0) root = p;
    }
    void splay(int p)
    {
        splayto(p, 0);
    }
    // In-order predecessor of node p, or 0 if none (or if p is 0).
    int pre(int p)
    {
        if (!p) return 0;
        if (!lc(p)) {
            while (p == lc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = lc(p);
            while (rc(p)) p = rc(p);
        }
        return p;
    }
    // In-order successor of node p, or 0 if none (or if p is 0).
    int suc(int p)
    {
        if (!p) return 0;
        if (!rc(p)) {
            while (p == rc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = rc(p);
            while (lc(p)) p = lc(p);
        }
        return p;
    }
    // Node at 1-based position x, or 0 if x is outside [1, size[root]].
    int find(int x)
    {
        // Without this check x <= 0 descends into node 0 forever.
        if (x < 1 || x > size[root]) return 0;
        int p = root;
        while (true) {
            if (x <= size[lc(p)]) {
                p = lc(p);
                continue;
            }
            x -= size[lc(p)];
            if (x <= 1)
                return p;
            x -= 1;
            p = rc(p);
        }
        return 0;
    }
    // Allocate a single-node subtree holding v with parent q.
    int makenode(int q, int v)
    {
        int p = ++ncnt;
        n++;
        lc(p) = rc(p) = 0;
        par(p) = q;
        size[p] = 1;
        sum[p] = v;
        return p;
    }
    // Account for one extra node of value v somewhere below p.
    void updatenode(int p, int v)
    {
        size[p]++;
        sum[p] += v;
    }
    // Insert v directly after position x.  Ignored when x or x + 1 is not a
    // valid position (e.g. after the right boundary node).
    void insert(int x, int v)
    {
        int lp = find(x);
        int rp = lp ? suc(lp) : 0;
        if (!lp || !rp) return;
        splayto(rp, 0);
        splayto(lp, root);
        // lp is rp's left child with an empty right subtree.
        int c = makenode(lp, v);
        rc(lp) = c;
        updatenode(lp, v);
        updatenode(rp, v);
    }
    // Sum of positions l..r inclusive; 0 for an empty or invalid range.
    int query_sum(int l, int r)
    {
        if (l > r) return 0;
        int lp = find(l - 1), rp = find(r + 1);
        if (!lp || !rp) return 0;
        splayto(rp, 0);
        splayto(lp, root);
        return sum[rc(lp)];
    }
    // Print parent-child edges in xmpaint input format.
    void debug()
    {
        for (int i = 1; i <= n; i++) {
            if (lc(i)) printf("%d %d %d\n", i, lc(i), 1);
            if (rc(i)) printf("%d %d %d\n", i, rc(i), 2);
        }
    }
    // Reset to the two boundary nodes: node 1 (root) with node 2 on its right.
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        rc(root) = makenode(root, 0);
        par(rc(root)) = root;
        updatenode(root, 0);
    }
} sp;
// Interactive driver: insert <pos> <val>, sum <l> <r>, debug.
// Positions include the left boundary node at position 1, so the first real
// element sits at position 2.  Reads until end of input.
int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    // Looping on the extraction result stops cleanly at EOF instead of
    // spinning forever as while (true) did.
    while (cin >> a)
    {
        int b = 0, c = 0;
        if (a == "insert") {
            cin >> b >> c;
            sp.insert(b, c);
        } else if (a == "sum") {
            cin >> b >> c;
            printf("sum %d %d = %d\n", b, c, sp.query_sum(b, c));
        } else if (a == "debug") {
            sp.debug();
        }
    }
    return 0;
}
This is the multi-operation edition, covering the operations of POJ 3580. It is expected to work but has not been thoroughly tested; for an accepted (judge-verified) version of this splay tree, refer to the post introducing POJ 3580, since further bug fixes will not be made on this page. Improvements include:
- Stabilized lazy-tag propagation and value updates based on the stored value of the current node.
- Added node removal.
- Added interval reversal.
- Added interval revolve (cyclic rotation of a range).
- Removed a large number of unused methods.
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
const int maxn = 10010;
const int infinit = 1000000007;
// Implicit-key splay tree supporting range add, range reverse, range revolve,
// insert, delete, range sum and range minimum (the POJ 3580 operation set).
//
// Node 0 is the null sentinel.  buildtree() creates two boundary nodes
// (ids 1 and 2) at the first and last positions; they hold 0 and are
// excluded from minima.  Position arguments are 1-based and count the left
// boundary node, so real elements occupy positions 2 .. size[root] - 1.
//
// Lazy tags: lazyadd[p] has already been applied to p's own val/sum/minn but
// is still pending for p's children; lazyswp[p] marks p's subtree as
// reversed.  The lc/rc macros read children through the node's own swap flag.
// Every structural operation first pushes tags down along the path it walks,
// so that lc/rc agree with the physical ch[][] layout there.
class SplayTree
{
public:
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], val[maxn], sum[maxn], minn[maxn];
    int lazyadd[maxn], lazyswp[maxn];
    #define lc(x) ch[x][lazyswp[x]]
    #define rc(x) ch[x][!lazyswp[x]]
    #define par(x) parent[x]
    // Allocate a fresh leaf holding v with parent q; returns its id.
    int makenode(int q, int v)
    {
        int p = ++ncnt;
        n++;
        // Clear the tags first: they decide which slot lc/rc refer to.
        lazyadd[p] = lazyswp[p] = 0;
        ch[p][0] = ch[p][1] = 0;
        par(p) = q;
        size[p] = 1;
        val[p] = sum[p] = minn[p] = v;
        return p;
    }
    // Recompute minn[p] from p's value and its children's minima; the two
    // boundary nodes (ids 1 and 2) never contribute their value.
    void updateminn(int p)
    {
        minn[p] = p > 2 ? val[p] : infinit;
        if (lc(p)) minn[p] = std::min(minn[p], minn[lc(p)]);
        if (rc(p)) minn[p] = std::min(minn[p], minn[rc(p)]);
    }
    // Recompute size, sum and minimum of p from its children.
    void pushup(int p)
    {
        size[p] = size[lc(p)] + 1 + size[rc(p)];
        sum[p] = sum[lc(p)] + val[p] + sum[rc(p)];
        updateminn(p);
    }
    // Add v to every value in p's subtree: p's own data now, children lazily.
    void applyadd(int p, int v)
    {
        if (!p) return;   // never write into the null sentinel's slots
        val[p] += v;
        sum[p] += size[p] * v;
        minn[p] += v;
        lazyadd[p] += v;
    }
    // Hand p's pending addition down to its children.
    void dispatchlazyadd(int p)
    {
        if (!lazyadd[p]) return;
        applyadd(lc(p), lazyadd[p]);
        applyadd(rc(p), lazyadd[p]);
        lazyadd[p] = 0;
    }
    // Materialise p's pending reversal: physically swap its children and mark
    // each child's subtree as reversed.  Returns whether a swap was pending.
    bool dispatchlazyswp(int p)
    {
        if (!lazyswp[p]) return false;
        if (ch[p][0]) lazyswp[ch[p][0]] ^= 1;
        if (ch[p][1]) lazyswp[ch[p][1]] ^= 1;
        std::swap(ch[p][0], ch[p][1]);
        lazyswp[p] = 0;
        return true;
    }
    // Push all of p's pending tags to its children.
    void pushdown(int p)
    {
        if (!p) return;
        dispatchlazyswp(p);
        dispatchlazyadd(p);
    }
    // Push tags down along the whole root-to-p path, topmost node first.
    void pushpath(int p)
    {
        if (par(p)) pushpath(par(p));
        pushdown(p);
    }
    // Rotate p above its parent, keeping the in-order sequence intact.
    void rotate(int p)
    {
        int q = par(p), g = par(q);
        // q first: dispatching q's reversal toggles p's own flag.
        pushdown(q);
        pushdown(p);
        // Both flags are now clear, so this is p's physical side under q.
        int x = (ch[q][1] == p);
        int b = ch[p][!x];
        ch[q][x] = b;
        if (b) par(b) = q;
        ch[p][!x] = q;
        par(q) = p;
        par(p) = g;
        // Compare physical slots: g may still carry a pending reversal.
        if (g) ch[g][ch[g][1] == q] = p;
        // q is now p's child, so it must be refreshed before p.
        pushup(q);
        pushup(p);
    }
    // Splay p until its parent is t (t == 0 means: up to the root).
    void splay(int p, int t)
    {
        if (!p) return;
        pushpath(p);
        for (int q = 0; (q = par(p)) && q != t; rotate(p)) {
            int g = par(q);
            if (g && g != t)
                rotate((p == lc(q)) == (q == lc(g)) ? q : p);
        }
        if (t == 0) root = p;
    }
    // In-order successor of node p, or 0 if none (or if p is 0).
    int suc(int p)
    {
        if (!p) return 0;
        pushpath(p);
        if (rc(p)) {
            p = rc(p);
            pushdown(p);
            while (lc(p)) {
                p = lc(p);
                pushdown(p);
            }
            return p;
        }
        while (par(p) && p == rc(par(p))) p = par(p);
        return par(p);
    }
    // Node at 1-based position x, or 0 if x is outside [1, size[root]].
    // Tags are pushed on the way down so reversed subtrees are read correctly.
    int find(int x)
    {
        if (x < 1 || x > size[root]) return 0;
        int p = root;
        while (p) {
            pushdown(p);
            int ls = size[lc(p)];
            if (x <= ls) {
                p = lc(p);
            } else if (x == ls + 1) {
                return p;
            } else {
                x -= ls + 1;
                p = rc(p);
            }
        }
        return 0;
    }
    // Make positions l..r exactly the right subtree of position l - 1, which
    // becomes the left child of position r + 1 at the root.  Returns that
    // subtree's root, or 0 if the range is empty or out of bounds.
    int isolate(int l, int r)
    {
        if (l > r) return 0;
        int lp = find(l - 1), rp = find(r + 1);
        if (!lp || !rp) return 0;
        splay(rp, 0);
        splay(lp, root);
        return rc(lp);
    }
    // Insert v directly after position x; ignored if x or x + 1 is invalid.
    void insert(int x, int v)
    {
        int lp = find(x), rp = find(x + 1);
        if (!lp || !rp) return;
        splay(rp, 0);
        splay(lp, root);
        // lp now sits right below rp with an empty right subtree.
        int c = makenode(lp, v);
        rc(lp) = c;
        pushup(lp);
        pushup(rp);
    }
    // Unlink the element at position x; ignored for boundary/invalid x.
    void remove(int x)
    {
        int lp = find(x - 1), rp = find(x + 1);
        if (!lp || !rp) return;
        splay(rp, 0);
        splay(lp, root);
        int c = rc(lp);
        if (!c) return;
        rc(lp) = 0;
        par(c) = 0;
        pushup(lp);
        pushup(rp);
        n--;
    }
    // Sum of positions l..r; 0 for an empty or invalid range.
    int query_sum(int l, int r)
    {
        int s = isolate(l, r);
        return s ? sum[s] : 0;
    }
    // Minimum of positions l..r; infinit for an empty or invalid range.
    int query_min(int l, int r)
    {
        int s = isolate(l, r);
        return s ? minn[s] : infinit;
    }
    // Add v to every element of positions l..r.
    void modify_add(int l, int r, int v)
    {
        int s = isolate(l, r);
        if (!s) return;
        applyadd(s, v);
        // Refresh the two ancestors whose aggregates cover the range.
        pushup(par(s));
        pushup(root);
    }
    // Reverse positions l..r.
    void modify_swp(int l, int r)
    {
        int s = isolate(l, r);
        if (s) lazyswp[s] ^= 1;
    }
    // Rotate positions l..r right by t steps (negative t rotates left):
    // the last t elements are cut out and re-attached in front of position l.
    void modify_revolve(int l, int r, int t)
    {
        if (l >= r || l - 1 < 1 || r + 1 > size[root]) return;
        int len = r - l + 1;
        t = (t % len + len) % len;
        if (t == 0) return;
        int seg = isolate(r - t + 1, r);
        if (!seg) return;
        int lp = par(seg), rp = root;
        rc(lp) = 0;
        par(seg) = 0;
        pushup(lp);
        pushup(rp);
        // Graft the detached block between positions l - 1 and l.
        lp = find(l - 1);
        rp = find(l);
        splay(rp, 0);
        splay(lp, root);
        rc(lp) = seg;
        par(seg) = lp;
        pushup(lp);
        pushup(rp);
    }
    // Reset to the two boundary nodes: node 1 (root) with node 2 on its right.
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        int tail = makenode(root, 0);
        rc(root) = tail;
        minn[root] = minn[tail] = infinit;
        pushup(root);
    }
} sp;
// Interactive driver for the multi-operation splay tree.  User positions are
// 1-based over the real elements; the tree's internal positions are shifted
// by one because of the left boundary node, hence the "+ 1" adjustments.
// Reads commands until end of input.
int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    // Looping on the extraction result stops cleanly at EOF instead of
    // spinning forever as while (true) did.
    while (cin >> a)
    {
        int b = 0, c = 0, d = 0;
        if (a == "insert") {
            cin >> b >> c;
            sp.insert(b + 1, c);
        } else if (a == "delete") {
            cin >> b;
            sp.remove(b + 1);
        } else if (a == "sum") {
            cin >> b >> c;
            printf("sum %d %d = %d\n", b, c, sp.query_sum(b + 1, c + 1));
        } else if (a == "min") {
            cin >> b >> c;
            printf("min %d %d = %d\n", b, c, sp.query_min(b + 1, c + 1));
        } else if (a == "add") {
            cin >> b >> c >> d;
            sp.modify_add(b + 1, c + 1, d);
        } else if (a == "reverse") {
            cin >> b >> c;
            sp.modify_swp(b + 1, c + 1);
        } else if (a == "revolve") {
            cin >> b >> c;
            // Rotate elements b..c right by one step: element c moves to
            // position b.  The original only copied element c to position b
            // and never removed it, so the sequence grew.
            if (b < c) {
                d = sp.query_sum(c + 1, c + 1);
                sp.remove(c + 1);
                // Internal position b is user position b - 1, so the value
                // lands at user position b.
                sp.insert(b, d);
            }
        }
    }
    return 0;
}