红黑树的简易实现[cpp version]

数组版写着就是舒服!

PS. 这是左倾红黑树(Left-Leaning Red-Black Tree)

#include<bits/stdc++.h>
using namespace std;

namespace util {
    /// Appends a copy of `arg` to the back of `v`.
    template<typename T>
    inline void empb(std::vector<T> &v,const T &arg) {
        v.push_back(arg);
    }

    /// Detaches the C++ streams from C stdio and unties std::cin / std::cout
    /// from any stream they are tied to.
    inline void fast_io() {
        std::ios_base::sync_with_stdio(false);
        std::cin.tie(nullptr);
        std::cout.tie(nullptr);
    }
}

using namespace util;

/**
 * Left-leaning red-black tree stored in parallel arrays, used as an ordered
 * multiset of ints.
 *
 * A node is an index into the vectors below. Index 0 is a sentinel "null"
 * node (BLACK, cnt 0, size 0) so child lookups never need a null check.
 * Equal keys share one node and are counted in cnt. Removal is lazy: it only
 * decrements cnt, so a node whose cnt reached 0 stays in the tree as a
 * tombstone and is skipped by the queries.
 */
struct RBT {
    int root;                                   // index of the root node, 0 when empty
    std::vector<int> key,cnt,size,color,lc,rc;  // per-node key, multiplicity, subtree size, colour, children
    static const int BLACK = 0;
    static const int RED   = 1;

    RBT():root(0) {
        // Create the sentinel at index 0 and make it count for nothing.
        state(0,BLACK);
        size[0] = cnt[0] = 0;
    }

    /// Allocates a new leaf holding key `k` with colour `c` (cnt 1, size 1)
    /// and returns its index.
    int state(int k,int c) {
        int n = static_cast<int>(lc.size());
        key.push_back(k);
        color.push_back(c);
        cnt.push_back(1);
        size.push_back(1);
        lc.push_back(0);
        rc.push_back(0);
        return n;
    }

    /// Inserts one occurrence of `key`.
    void add(int key) {
        // add(0, key) creates the first node, so the empty tree needs no
        // special case and the root is recoloured BLACK after every insert.
        int t = add(root,key);
        root = t;
        color[root] = BLACK;
    }

    /// Recomputes size[cur] from its multiplicity and its children's sizes.
    /// The sentinel (index 0) is left untouched.
    void push_up(int cur) {
        if(cur) size[cur] = cnt[cur]+size[lc[cur]]+size[rc[cur]];
    }

    /// True when `cur` is a real node coloured RED; the sentinel is never red.
    bool red(int cur) {
        return cur ? color[cur] == RED : false;
    }

    /// Left rotation: the right child of `cur` becomes the subtree root,
    /// inherits cur's colour, and `cur` becomes its RED left child.
    /// Returns the new subtree root.
    int rl(int cur) {
        int fa = rc[cur];
        rc[cur] = lc[fa];
        lc[fa] = cur;

        color[fa] = color[cur];
        color[cur] = RED;

        // cur is now the lower node, so refresh it first.
        push_up(cur);
        push_up(fa);

        return fa;
    }

    /// Right rotation: mirror image of rl(). Returns the new subtree root.
    int rr(int cur) {
        int fa = lc[cur];
        lc[cur] = rc[fa];
        rc[fa] = cur;

        color[fa] = color[cur];
        color[cur] = RED;

        push_up(cur);
        push_up(fa);

        return fa;
    }

    /// Toggles the colour of `cur` and of both its children.
    void flip(int cur) {
        color[cur] ^= 1;
        color[lc[cur]] ^= 1;
        color[rc[cur]] ^= 1;
    }

    /// Inserts `k` into the subtree rooted at `cur` and returns the index of
    /// the (possibly different) subtree root after rebalancing.
    int add(int cur,int k) {
        if(!cur) return state(k,RED);

        // Keys are compared directly: `k - key[cur]` overflows for distant values.
        if(k < key[cur]) {
            // Store the result first: the recursive call may grow the vectors,
            // and before C++17 the left-hand lc[cur] could be evaluated before
            // the call and dangle after a reallocation.
            int child = add(lc[cur],k);
            lc[cur] = child;
        } else if(k > key[cur]) {
            int child = add(rc[cur],k);
            rc[cur] = child;
        } else {
            cnt[cur]++;  // duplicate key: size is refreshed by push_up below
        }

        // Restore the left-leaning shape on the way back up.
        if(red(rc[cur]) && !red(lc[cur])) cur = rl(cur);
        if(red(lc[cur]) && red(lc[lc[cur]])) cur = rr(cur);
        if(red(lc[cur]) && red(rc[cur])) flip(cur);
        push_up(cur);
        return cur;
    }

    /// Removes one occurrence of `k`. Returns false when `k` has no live
    /// occurrence. The node itself is kept (lazy deletion).
    bool remove(int k) {
        std::vector<int> path;  // nodes visited from the root down to the match
        int cur = root;
        while(cur) {
            path.push_back(cur);
            if(k < key[cur]) cur = lc[cur];
            else if(k > key[cur]) cur = rc[cur];
            else {
                if(cnt[cur] == 0) return false;
                cnt[cur]--;
                // Refresh sizes bottom-up along the search path.
                for(auto it = path.rbegin(); it != path.rend(); ++it) push_up(*it);
                return true;
            }
        }
        return false;
    }

    /// Returns the rank of `k`: the number of stored elements strictly
    /// smaller than `k`, plus one. `k` itself does not have to be present.
    int rank(int k) {
        int cur = root;
        int smaller = 0;
        while(cur) {
            if(k < key[cur]) {
                cur = lc[cur];
            } else if(k > key[cur]) {
                smaller += size[lc[cur]]+cnt[cur];
                cur = rc[cur];
            } else {
                return smaller+size[lc[cur]]+1;
            }
        }
        return smaller+1;
    }

    /// Returns the index of the node holding the `kth` smallest element
    /// (1-based, duplicates counted), or 0 when `kth` is out of range.
    int kth_node(int kth) {
        int cur = root;
        while(cur) {
            int sz = size[lc[cur]];
            if(sz >= kth) {
                cur = lc[cur];
            } else if(kth-sz > cnt[cur]) {
                kth -= (sz+cnt[cur]);
                cur = rc[cur];
            } else {
                return cur;
            }
        }
        return 0;
    }

    /// Largest live key strictly smaller than `k`, or INT_MIN when none exists.
    int lower(int k) {
        return lower(root,k);
    }

    /// Same as lower(k), restricted to the subtree rooted at `cur`.
    int lower(int cur,int k) {
        int ans = INT_MIN;
        while(cur) {
            if(key[cur] < k) {
                if(cnt[cur] > 0) {
                    // Every node reached from here on is in the right subtree
                    // of the previous candidate, so this key is never smaller.
                    ans = key[cur];
                    cur = rc[cur];
                } else {
                    // Tombstone: the answer may be on either side of it.
                    return std::max(ans,std::max(lower(lc[cur],k),lower(rc[cur],k)));
                }
            } else {
                cur = lc[cur];
            }
        }
        return ans;
    }

    /// Smallest live key strictly greater than `k`, or INT_MAX when none exists.
    int upper(int k) {
        return upper(root,k);
    }

    /// Same as upper(k), restricted to the subtree rooted at `cur`.
    int upper(int cur,int k) {
        int ans = INT_MAX;
        while(cur) {
            if(key[cur] > k) {
                if(cnt[cur] > 0) {
                    ans = key[cur];
                    cur = lc[cur];
                } else {
                    // Tombstone: the answer may be on either side of it.
                    return std::min(ans,std::min(upper(lc[cur],k),upper(rc[cur],k)));
                }
            } else {
                cur = rc[cur];
            }
        }
        return ans;
    }
};

/**
 * Reads a count n followed by up to n lines "op key" from stdin and applies
 * each to a single RBT:
 *   1 insert key          2 remove one occurrence of key
 *   3 print rank(key)     4 print the key of the key-th smallest element
 *   5 print lower(key)    6 print upper(key)
 * Any other op is ignored. For op 4 an out-of-range position prints 0, the
 * key stored in the sentinel node.
 */
int main() {
    fast_io();
    RBT rbt;
    int n = 0;
    cin >> n;
    for(int i = 1; i <= n; i++) {
        int op,key;
        // Stop on truncated or malformed input instead of spinning on a failed stream.
        if(!(cin >> op >> key)) break;
        switch(op) {
            case 1: rbt.add(key); break;
            case 2: rbt.remove(key); break;
            // '\n' rather than endl: flushing after every answer would undo fast_io().
            case 3: cout << rbt.rank(key) << '\n'; break;
            case 4: cout << rbt.key[rbt.kth_node(key)] << '\n'; break;
            case 5: cout << rbt.lower(key) << '\n'; break;
            case 6: cout << rbt.upper(key) << '\n'; break;
            default : break;
        }
    }
    return 0;
}

附加:一个跑得飞快但可能是假的 size-balanced-tree https://paste.ubuntu.com/p/qSr8Bby3Mw/

发表评论

电子邮件地址不会被公开。 必填项已用*标注