Basic Principles

See Splay Trees on Wikipedia.

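A splay tree is a kind of self-adjusting binary search tree. It keeps the ordering property of a BST, but instead of storing explicit balance information it restructures itself on every access by rotating the accessed node up to the root ("splaying"). Operations run in O(log n) amortized time, although a single operation can cost O(n) in the worst case, when the tree has degenerated into a chain. In the author's experience, splay trees have better constant factors than red-black trees, which tend to run rather slowly.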

Sample Code

The first sample is a plain, key-ordered splay tree. It contains only the structural operations (rotate, splay, insert, predecessor/successor) and no query operations.


#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;
// Capacity of the node pool; index 0 is reserved as the null node.
constexpr int maxn = 1010;

class SplayTree
{
public:
  // Structural arrays. Node 0 is the null node; real nodes are 1..ncnt.
  // ch[x][0] / ch[x][1] are the left / right children of x, parent[x] its parent.
  int parent[maxn], ch[maxn][2], root, ncnt;
  // Key of each node; the tree is a BST ordered by val (equal keys go right).
  int val[maxn];
  #define lc(x) ch[x][0]
  #define rc(x) ch[x][1]
  #define par(x) parent[x]
  // Rotates p above its parent q. The in-order sequence is unchanged.
  void rotate(int p)
  {
    int q = par(p), y = par(q), x = (rc(q) == p);
    int b = ch[p][!x];   // inner subtree of p, which moves across to q
    ch[q][x] = b;
    if (b) par(b) = q;   // never write parent[] of the null node
    ch[p][!x] = q;
    par(q) = p;
    par(p) = y;
    if (y) {
      if (lc(y) == q)
        lc(y) = p;
      else
        rc(y) = p;
    }
  }
  // Moves p to the root using zig-zig / zig-zag double rotations.
  void splay(int p)
  {
    for (int q; (q = par(p)) != 0; rotate(p))
      if (par(q))
        rotate((p == lc(q)) == (q == lc(par(q))) ? q : p);
    root = p;
  }
  // Inserts key v below the (non-null) node p, then splays the new node to the root.
  void insert(int p, int v)
  {
    int q = 0;
    while (true) {
      q = (v < val[p]) ? lc(p) : rc(p);
      if (q == 0) break;
      p = q;
    }
    q = ++ncnt;
    val[q] = v;
    lc(q) = rc(q) = 0;
    par(q) = p;
    if (v < val[p]) lc(p) = q;
    else rc(p) = q;
    splay(q);
  }
  // Inserts key v into the tree; the first key becomes the root.
  void insert(int v)
  {
    if (root == 0) {
      root = ++ncnt;
      val[root] = v;
      lc(root) = rc(root) = 0;
      par(root) = 0;
    } else {
      insert(root, v);
    }
  }
  // Largest node in p's left subtree, or 0 if p has no left child.
  int pre(int p)
  {
    int q = lc(p);
    while (rc(q))
      q = rc(q);
    return q;
  }
  // Smallest node in p's right subtree, or 0 if p has no right child.
  int suc(int p)
  {
    int q = rc(p);
    while (lc(q))
      q = lc(q);
    return q;
  }
  // Dumps every allocated node (1..ncnt).
  void debug(void)
  {
    for (int i = 1; i <= ncnt; i++) {
      printf("#%d: par %d lc %d rc %d val %d\n", i, parent[i], ch[i][0], ch[i][1], val[i]);
    }
  }
  // Starts with an empty tree; every array and counter is zeroed so that
  // non-global instances are usable as well.
  SplayTree()
  {
    memset(ch, 0, sizeof(ch));
    memset(parent, 0, sizeof(parent));
    memset(val, 0, sizeof(val));
    root = 0;
    ncnt = 0;
  }
} sp;

The second version treats the tree as a sequence (an implicit-key splay tree): insert places a new element after a given numeric position. Its splay() uses only single rotations instead of the zig-zig/zig-zag double rotations. This makes it easier to debug, but gives up the O(log n) amortized time bound. The goal of this version is a correctly working tree rather than a fast but buggy one.


#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
// Capacity of the node pool; index 0 is reserved as the null node.
constexpr int maxn = 10010;

class SplayTree
{
public:
    // Sequence splay tree keyed by position (implicit key = subtree sizes).
    // ch[x][0] / ch[x][1]: left / right child; parent[x]: parent; 0 = null node.
    // size[x]: number of nodes in x's subtree. sum[] is reserved for sum
    // queries and is not maintained in this version.
    // n counts created nodes, including the two nodes made by buildtree().
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], sum[maxn];
    #define lc(x) ch[x][0]
    #define rc(x) ch[x][1]
    #define par(x) parent[x]
    // Rotates p above its parent q, keeping the in-order sequence and sizes.
    void rotate(int p)
    {
        int q = par(p), g = par(q), x = p == rc(q);
        int b = ch[p][!x];              // inner subtree of p, which moves to q
        size[q] -= size[p];
        size[p] -= size[b];
        ch[q][x] = b;
        if (b) par(b) = q;              // never write parent[] of the null node
        ch[p][!x] = q, par(q) = p;
        par(p) = g;
        size[q] += size[b];
        size[p] += size[q];
        if (g) ch[g][rc(g) == q] = p;
    }
    // Single-rotation (move-to-root) splay; ignores ids that are not allocated
    // nodes, so a bad id from user input cannot clobber root.
    void splay(int p)
    {
        if (p <= 0 || p > ncnt) return;
        while (par(p))
            rotate(p);
        root = p;
    }
    // In-order predecessor of node p, or 0 if p is first (or p is 0).
    int pre(int p)
    {
        if (!p) return 0;
        if (!lc(p)) {
            while (p == lc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = lc(p);
            while (rc(p)) p = rc(p);
        }
        return p;
    }
    // In-order successor of node p, or 0 if p is last (or p is 0).
    int suc(int p)
    {
        if (!p) return 0;
        if (!rc(p)) {
            while (p == rc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = rc(p);
            while (lc(p)) p = lc(p);
        }
        return p;
    }
    // Node at 1-based position x, or 0 if x is outside [1, size[root]].
    int find(int x)
    {
        if (x < 1 || x > size[root]) return 0;
        int p = root;
        while (p) {
            if (x <= size[lc(p)]) {
                p = lc(p);
                continue;
            }
            x -= size[lc(p)];
            if (x <= 1)
                return p;
            x -= 1;
            p = rc(p);
        }
        return 0;
    }
    // Allocates a detached leaf with parent q. v is unused until sums are enabled.
    int makenode(int q, int v)
    {
        int p = ++ncnt;
        n++;
        lc(p) = rc(p) = 0;
        par(p) = q;
        size[p] = 1;
        // sum[p] = v;
        return p;
    }
    // Accounts for one extra node below p.
    void updatenode(int p, int v)
    {
        size[p]++;
        // sum[p] += v;
    }
    // Inserts a new node after position x; invalid positions are ignored.
    void insert(int x, int v)
    {
        int lp = find(x);
        if (!lp) return;
        int rp = suc(lp);
        splay(lp);
        if (rp) splay(rp);              // lp is now rp's left child with no right child
        int c = makenode(lp, v);
        rc(lp) = c;
        updatenode(lp, v);
        if (rp) updatenode(rp, v);
    }
    // Prints parent/child edges (1 = left, 2 = right) for a graph painter;
    // a missing child is drawn as a fake node i + 100.
    void debug()
    {
        for (int i = 1; i <= n; i++) {
            if (lc(i)) printf("%d %d %d\n", i, lc(i), 1);
            else if (rc(i)) printf("%d %d %d\n", i, i + 100, 1);
            if (rc(i)) printf("%d %d %d\n", i, rc(i), 2);
            else if (lc(i)) printf("%d %d %d\n", i, i + 100, 2);
        }
    }
    // Builds the initial two-node tree: node 1 (root) with node 2 as its right child.
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        rc(root) = makenode(root, 0);   // makenode already links the parent
        updatenode(root, 0);
    }
} sp;

int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    int b = 0, c = 0;
    // Process commands until end of input (or a malformed argument).
    while (cin >> a)
    {
        if (a == "insert") {
            if (!(cin >> b >> c)) break;
            sp.insert(b, c);
        } else if (a == "splay") {
            if (!(cin >> b)) break;
            sp.splay(b);
        } else if (a == "pre") {
            if (!(cin >> b)) break;
            printf("pre %d = %d\n", b, sp.pre(b));
        } else if (a == "suc") {
            if (!(cin >> b)) break;
            printf("suc %d = %d\n", b, sp.suc(b));
        } else if (a == "find") {
            if (!(cin >> b)) break;
            printf("find %d = %d\n", b, sp.find(b));
        } else if (a == "debug") {
            sp.debug();
        }
    }
    return 0;
}

This third version makes several important changes to the splay tree:

  • Splaying now uses double rotations (zig-zig / zig-zag).
  • A node can be splayed to just below a given target node (splayto), so that an interval can be isolated between two adjacent nodes.
  • Interval sum queries are supported, but interval addition and other advanced editing operations are not.

#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
// Capacity of the node pool; index 0 is reserved as the null node.
constexpr int maxn = 10010;

class SplayTree
{
public:
    // Sequence splay tree keyed by position (implicit key = subtree sizes).
    // ch[x][0] / ch[x][1]: left / right child; parent[x]: parent; 0 = null node.
    // size[x] / sum[x]: node count / value sum of x's subtree.
    // buildtree() creates nodes 1 and 2 with value 0 as head and tail markers.
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], sum[maxn];
    #define lc(x) ch[x][0]
    #define rc(x) ch[x][1]
    #define par(x) parent[x]
    // Rotates p above its parent q, keeping the in-order sequence, sizes and sums.
    void rotate(int p)
    {
        int q = par(p), g = par(q), x = p == rc(q);
        int b = ch[p][!x];              // inner subtree of p, which moves to q
        // Detach p's subtree from q's totals and b from p's totals...
        size[q] -= size[p];
        size[p] -= size[b];
        sum[q] -= sum[p];
        sum[p] -= sum[b];
        ch[q][x] = b;
        if (b) par(b) = q;              // never write parent[] of the null node
        ch[p][!x] = q, par(q) = p;
        par(p) = g;
        // ...then re-add them in their new places.
        size[q] += size[b];
        size[p] += size[q];
        sum[q] += sum[b];
        sum[p] += sum[q];
        if (g) ch[g][rc(g) == q] = p;
    }
    // Splays p (double rotations) until its parent is t; t = 0 splays to the root.
    // Ids that are not allocated nodes are ignored.
    void splayto(int p, int t)
    {
        if (p <= 0 || p > ncnt) return;
        for (int q; (q = par(p)) && q != t; rotate(p))
            if (par(q) && par(q) != t)
                rotate((p == lc(q)) == (q == lc(par(q))) ? q : p);
        if (!par(p)) root = p;
    }
    // Splays p to the root.
    void splay(int p)
    {
        splayto(p, 0);
    }
    // In-order predecessor of node p, or 0 if p is first (or p is 0).
    int pre(int p)
    {
        if (!p) return 0;
        if (!lc(p)) {
            while (p == lc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = lc(p);
            while (rc(p)) p = rc(p);
        }
        return p;
    }
    // In-order successor of node p, or 0 if p is last (or p is 0).
    int suc(int p)
    {
        if (!p) return 0;
        if (!rc(p)) {
            while (p == rc(par(p))) p = par(p);
            p = par(p);
        } else {
            p = rc(p);
            while (lc(p)) p = lc(p);
        }
        return p;
    }
    // Node at 1-based position x, or 0 if x is outside [1, size[root]].
    int find(int x)
    {
        if (x < 1 || x > size[root]) return 0;
        int p = root;
        while (p) {
            if (x <= size[lc(p)]) {
                p = lc(p);
                continue;
            }
            x -= size[lc(p)];
            if (x <= 1)
                return p;
            x -= 1;
            p = rc(p);
        }
        return 0;
    }
    // Allocates a detached leaf holding v with parent q.
    int makenode(int q, int v)
    {
        int p = ++ncnt;
        n++;
        lc(p) = rc(p) = 0;
        par(p) = q;
        size[p] = 1;
        sum[p] = v;
        return p;
    }
    // Accounts for one extra node of value v below p.
    void updatenode(int p, int v)
    {
        size[p]++;
        sum[p] += v;
    }
    // Inserts v after position x. Position x must have a successor (the tail
    // marker guarantees this for real elements); otherwise nothing happens.
    void insert(int x, int v)
    {
        int lp = find(x);
        int rp = lp ? suc(lp) : 0;
        if (!lp || !rp) return;
        splayto(rp, 0);
        splayto(lp, root);              // lp = rp's left child, with no right child
        int c = makenode(lp, v);
        rc(lp) = c;
        updatenode(lp, v);
        updatenode(rp, v);
    }
    // Sum of the values at tree positions l..r. Requires 2 <= l and
    // r + 1 <= size[root]; returns 0 for an empty (l == r + 1) or invalid range.
    int query_sum(int l, int r)
    {
        if (l < 2 || r + 1 > size[root] || l > r + 1) return 0;
        int lp = find(l - 1), rp = find(r + 1);
        splayto(rp, 0);
        splayto(lp, root);              // positions l..r are now exactly rc(lp)
        return sum[rc(lp)];
    }
    // Prints parent/child edges (1 = left, 2 = right) in xmpaint grammar.
    void debug()
    {
        for (int i = 1; i <= n; i++) {
            if (lc(i)) printf("%d %d %d\n", i, lc(i), 1);
            if (rc(i)) printf("%d %d %d\n", i, rc(i), 2);
        }
    }
    // Builds the head (node 1, root) and tail (node 2, right child) markers.
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        rc(root) = makenode(root, 0);   // makenode already links the parent
        updatenode(root, 0);
    }
} sp;

int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    int b = 0, c = 0;
    // Process commands until end of input (or a malformed argument).
    // Positions are raw tree positions; position 1 is the head marker.
    while (cin >> a)
    {
        if (a == "insert") {
            // insert b c: insert value c after tree position b
            if (!(cin >> b >> c)) break;
            sp.insert(b, c);
        } else if (a == "sum") {
            // sum b c: sum of the values at tree positions b..c
            if (!(cin >> b >> c)) break;
            printf("sum %d %d = %d\n", b, c, sp.query_sum(b, c));
        } else if (a == "debug") {
            sp.debug();
        }
    }
    return 0;
}

This is the multi-operation edition, supporting the operations of POJ 3580. It is expected to work but has not been thoroughly tested yet. For an accepted version of this splay tree, refer to the post introducing POJ 3580; further bug fixes will not be made on this page. Improvements include:

  • Stabilized lazy propagation, with value updates based on the value stored at the current node.
  • Added node removal.
  • Added interval reversal.
  • Added interval revolve (rotation).
  • Removed a large number of unused methods.

#include <iostream>
#include <cstdlib>
#include <cstdio>

using namespace std;
// Capacity of the node pool; index 0 is reserved as the null node.
constexpr int maxn = 10010;
// "No value" marker for minimum queries (larger than any stored value).
constexpr int infinit = 1000000007;

class SplayTree
{
public:
    int ch[maxn][2], parent[maxn], root, ncnt, n;
    int size[maxn], val[maxn], sum[maxn], minn[maxn];
    int lazyadd[maxn], lazyswp[maxn];
    #define lc(x) ch[x][lazyswp[x]]
    #define rc(x) ch[x][!lazyswp[x]]
    #define par(x) parent[x]
    int makenode(int q, int v)
    {
        int p = ++ncnt; n++;
        lc(p) = rc(p) = 0;
        par(p) = q;
        size[p] = 1;
        val[p] = sum[p] = minn[p] = v;
        lazyadd[p] = lazyswp[p] = 0; // Initially they aren't lazy at all
        return p;
    }
    void updateminn(int p)
    {
        minn[p] = p > 2 ? val[p] : infinit;
        if (lc(p)) minn[p] = min(minn[p], minn[lc(p)]);
        if (rc(p)) minn[p] = min(minn[p], minn[rc(p)]);
        return ;
    }
    void dispatchlazyadd(int p)
    {
        // Separate dispatched lazy values to children
        lazyadd[lc(p)] += lazyadd[p];
        lazyadd[rc(p)] += lazyadd[p];
        // Update children's initial values
        val[lc(p)] += lazyadd[p];
        val[rc(p)] += lazyadd[p];
        // Update children's sums
        sum[lc(p)] += size[lc(p)] * lazyadd[p];
        sum[rc(p)] += size[rc(p)] * lazyadd[p];
        // Update minimum queried values
        minn[lc(p)] += lazyadd[p];
        minn[rc(p)] += lazyadd[p];
        // Finally reset lazy value
        lazyadd[p] = 0;
        return ;
    }
    bool dispatchlazyswp(int p)
    {
        if (!lazyswp[p]) return false;
        lazyswp[lc(p)] ^= 1;
        lazyswp[rc(p)] ^= 1;
        swap(lc(p), rc(p));
        lazyswp[p] = 0;
        return true;
    }
    void rotate(int p)
    {
        int q = par(p), g = par(q), x = p == rc(q);
        // Dispatching lazy values in case something goes wrong
        dispatchlazyadd(q);
        dispatchlazyadd(p);
        if (dispatchlazyswp(q)) x ^= 1; // These make no modifications to the actual values
        dispatchlazyswp(p);
        // Relink connexions between nodes
        ch[q][x] = ch[p][!x], par(ch[q][x]) = q;
        ch[p][!x] = q, par(q) = p;
        par(p) = g;
        if (g) ch[g][rc(g) == q] = p;
        // Update data values
        size[q] = size[lc(q)] + 1 + size[rc(q)];
        size[p] = size[lc(p)] + 1 + size[rc(p)];
        sum[q] = sum[lc(q)] + val[q] + sum[rc(q)];
        sum[p] = sum[lc(p)] + val[p] + sum[rc(p)];
        updateminn(p);
        updateminn(q);
        return ;
    }
    void splay(int p, int t)
    {
        for (int q = 0; (q = par(p)) && q != t; rotate(p))
            if (par(q) && par(q) != t)
                rotate((p == lc(q)) == (q == lc(par(q))) ? q : p);
        if (t == 0) root = p;
        return ;
    }
    int suc(int p)
    {
        if (!rc(p)) { while (p == rc(par(p))) p = par(p); p = par(p); }
        else { p = rc(p); while (lc(p)) p = lc(p); }
        return p;
    }
    int find(int x)
    {
        int p = root;
        while (true) {
            if (x <= size[lc(p)]) {
                p = lc(p);
                continue;
            } x -= size[lc(p)];
            if (x <= 1)
                return p;
            x -= 1;
            p = rc(p);
        }
        return 0;
    }
    void insert(int x, int v)
    {
        int lp = find(x), rp = suc(lp); // Operations should be guranteed that rp is valid
        splay(rp, 0);
        splay(lp, root);
        int c = makenode(lp, v);
        rc(lp) = c;
        size[lp]++, sum[lp] += v;
        size[rp]++, sum[rp] += v;
        updateminn(lp);
        updateminn(rp);
        return ;
    }
    void remove(int x)
    {
        int lp = find(x - 1), rp = suc(x);
        splay(rp, 0);
        splay(lp, root);
        int c = rc(lp);
        size[lp]--, sum[lp] -= val[c];
        size[rp]--, sum[rp] -= val[c];
        updateminn(lp);
        updateminn(rp);
        n--;
        return ;
    }
    int query_sum(int l, int r)
    {
        int lp = find(l - 1), rp = find(r + 1);
        splay(rp, 0);
        splay(lp, root);
        // Return data values
        return sum[rc(lp)];
    }
    int query_min(int l, int r)
    {
        int lp = find(l - 1), rp = find(r + 1);
        splay(rp, 0);
        splay(lp, root);
        // Return data values
        return minn[rc(lp)];
    }
    void modify_add(int l, int r, int v)
    {
        int lp = find(l - 1), rp = find(r + 1);
        splay(rp, 0);
        splay(lp, root);
        // Update data values
        sum[rc(lp)] += size[rc(lp)] * v;
        sum[lp] += size[rc(lp)] * v;
        sum[rp] += size[rc(lp)] * v;
        minn[rc(lp)] += v;
        val[rc(lp)] += v;
        printf("$ modify_add: it is %d who's talking about\n", rc(lp));
        updateminn(lp);
        updateminn(rp);
        lazyadd[rc(lp)] += v;
        return ;
    }
    void modify_swp(int l, int r)
    {
        int lp = find(l - 1), rp = find(r + 1);
        splay(rp, 0);
        splay(lp, root);
        // Updating data values, which were easier
        lazyswp[rc(lp)] ^= 1;
        return ;
    }
    void buildtree()
    {
        n = ncnt = 0;
        root = makenode(0, 0);
        rc(root) = makenode(root, 0);
        minn[1] = minn[2] = infinit;
        par(rc(root)) = root;
        size[root]++;
        return ;
    }
} sp;

int main()
{
    sp.buildtree();
    printf("Program begun.\n");
    string a;
    int b = 0, c = 0, d = 0;
    // Element k (1-based) lives at tree position k + 1, because tree
    // position 1 is the head sentinel. Commands run until end of input.
    while (cin >> a)
    {
        if (a == "insert") {
            // insert b c: insert value c after element b (b = 0: at the front)
            if (!(cin >> b >> c)) break;
            sp.insert(b + 1, c);
        } else if (a == "delete") {
            if (!(cin >> b)) break;
            sp.remove(b + 1);
        } else if (a == "sum") {
            if (!(cin >> b >> c)) break;
            printf("sum %d %d = %d\n", b, c, sp.query_sum(b + 1, c + 1));
        } else if (a == "min") {
            if (!(cin >> b >> c)) break;
            printf("min %d %d = %d\n", b, c, sp.query_min(b + 1, c + 1));
        } else if (a == "add") {
            if (!(cin >> b >> c >> d)) break;
            sp.modify_add(b + 1, c + 1, d);
        } else if (a == "reverse") {
            if (!(cin >> b >> c)) break;
            sp.modify_swp(b + 1, c + 1);
        } else if (a == "revolve") {
            // revolve b c: rotate elements b..c right by one step, i.e. move
            // element c to just before element b.
            if (!(cin >> b >> c)) break;
            int count = sp.size[sp.root] - 2;   // real elements, without sentinels
            if (1 <= b && b < c && c <= count) {
                d = sp.query_sum(c + 1, c + 1); // value of element c
                sp.remove(c + 1);
                sp.insert(b, d);                // after element b - 1
            }
        }
    }
    return 0;
}