红黑树的简易实现

肝了好几个小时的成品

大概通过了洛谷12个测试点中的11个(其中一个TLE,时限开得太紧了)

简要说明:

1.add的旋转分类参考自算法第4版实现(大概就是右red/左两个red/一左一右red)

2.删除使用了LAZY标记,分类删除?不存在的

题面

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

1. 插入x数

2. 删除x数(若有多个相同的数,应只删除一个)

3. 查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,应输出最小的排名)

4. 查询排名为x的数

5. 求x的前驱(前驱定义为小于x,且最大的数)

6. 求x的后继(后继定义为大于x,且最小的数)

输入格式

第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号( 1≤opt≤6 )

输出格式

对于操作3,4,5,6,每行输出一个数,表示对应答案

输入样例

10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598

输出样例

106465
84185
492737
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

enum Color {
    RED,
    BLACK
}

/**
 * Tree node for {@link RedBlackTree}.
 *
 * <p>Equal keys share a single node: {@code cnt} is the key's multiplicity and {@code size} is
 * the sum of {@code cnt} over the whole subtree rooted here. A node whose {@code cnt} has dropped
 * to 0 is lazily deleted: it stays linked into the tree but contributes no values.
 */
class Node implements Comparable<Node> {

    public int key;
    public int size;
    public int cnt;
    public Color color;
    public Node lc, rc;

    /** Creates a node holding one copy of {@code key}. */
    public Node(int key, int size, Color color) {
        this.key = key;
        this.size = size;
        this.color = color;
        this.cnt = 1;
    }

    /**
     * Orders nodes by key (the int key could be generalized to a Comparable).
     * Uses Integer.compare because {@code this.key - that.key} overflows for keys far apart.
     */
    @Override
    public int compareTo(Node that) {
        return Integer.compare(this.key, that.key);
    }

    @Override
    public String toString() {
        return "" + key;
    }
}

/**
 * Order-statistic multiset of ints built on a left-leaning red-black tree; the insertion
 * rebalancing cases follow Algorithms, 4th ed. (right red / two lefts red / both children red).
 *
 * <p>Deletion is lazy: {@link #remove(int)} only decrements a node's {@code cnt} and subtree
 * sizes; nodes are never unlinked. All queries therefore rely on {@code size}/{@code cnt}
 * rather than on node presence. Keys are compared with relational operators, never by
 * subtraction, so the full int range is supported.
 */
class RedBlackTree {

    public Node root;

    private static final Color RED = Color.RED;
    private static final Color BLACK = Color.BLACK;

    /** Rotates the subtree rooted at {@code n} to the right and returns its new root. */
    private Node rr(Node n) {
        Node fa = n.lc;
        n.lc = fa.rc;
        fa.rc = n;

        // The new root inherits n's color; n becomes the red child.
        fa.color = n.color;
        n.color = RED;

        // n is now below fa, so recompute n first.
        pushUp(n);
        pushUp(fa);
        return fa;
    }

    /** Rotates the subtree rooted at {@code n} to the left and returns its new root. */
    private Node rl(Node n) {
        Node fa = n.rc;
        n.rc = fa.lc;
        fa.lc = n;

        fa.color = n.color;
        n.color = RED;

        pushUp(n);
        pushUp(fa);
        return fa;
    }

    private Color flippedColor(Node n) {
        return isRed(n) ? BLACK : RED;
    }

    /** Inverts the colors of {@code n} and both children; only called when both children are red. */
    private void flip(Node n) {
        n.color = flippedColor(n);
        n.lc.color = flippedColor(n.lc);
        n.rc.color = flippedColor(n.rc);
    }

    private boolean isRed(Node n) {
        return n != null && n.color == RED;
    }

    private int size(Node n) {
        return n == null ? 0 : n.size;
    }

    /** Recomputes {@code n.size} from its own multiplicity and its children's sizes. */
    private void pushUp(Node n) {
        if (n == null) return;
        n.size = n.cnt + size(n.lc) + size(n.rc);
    }

    /** Inserts one copy of {@code key}. The root is always recolored black afterwards. */
    public void add(int key) {
        root = add(root, key);
        root.color = BLACK;
    }

    private Node add(Node n, int key) {
        if (n == null) return new Node(key, 1, RED);
        if (key < n.key) {
            n.lc = add(n.lc, key);
        } else if (key > n.key) {
            n.rc = add(n.rc, key);
        } else {
            // Existing (possibly lazily deleted) key: only the multiplicity changes, not the shape.
            n.cnt++;
            n.size++;
            return n;
        }

        if (isRed(n.rc) && !isRed(n.lc)) n = rl(n);
        if (isRed(n.lc) && isRed(n.lc.lc)) n = rr(n);
        if (isRed(n.lc) && isRed(n.rc)) flip(n);

        pushUp(n);
        return n;
    }

    /** Returns true if at least one copy of {@code key} is currently stored. */
    public boolean get(int key) {
        Node n = findNode(key);
        // A lazily deleted node is still physically present but holds no copies.
        return n != null && n.cnt > 0;
    }

    /**
     * Removes one copy of {@code key}.
     *
     * @return true if a copy was removed, false if {@code key} was not stored
     */
    public boolean remove(int key) {
        Node target = findNode(key);
        if (target == null || target.cnt == 0) return false;
        // Walk the same root-to-target path again, shrinking every subtree on it by one.
        Node cur = root;
        while (cur != target) {
            cur.size--;
            cur = key < cur.key ? cur.lc : cur.rc;
        }
        target.size--;
        target.cnt--;
        return true;
    }

    /**
     * Returns the node whose key equals {@code key}, or null if none exists.
     * The returned node may be lazily deleted ({@code cnt == 0}).
     */
    public Node findNode(int key) {
        Node cur = root;
        while (cur != null && cur.key != key) {
            cur = key < cur.key ? cur.lc : cur.rc;
        }
        return cur;
    }

    /**
     * Returns the rank of {@code key}: the number of stored values smaller than it, plus one.
     * {@code key} itself need not be stored.
     */
    public int getRank(int key) {
        return countBelow(root, key, false) + 1;
    }

    /** Returns the node holding the {@code kth} smallest stored value (1-based), or null if out of range. */
    public Node getKth(int kth) {
        return select(root, kth);
    }

    /** Largest stored value strictly smaller than {@code key}, or Integer.MIN_VALUE if none. */
    public int lower(int key) {
        return lower(root, key);
    }

    /**
     * Largest value stored in the subtree {@code cur} that is strictly smaller than {@code key},
     * or Integer.MIN_VALUE if there is none. Runs in O(height) regardless of lazy deletions.
     */
    public int lower(Node cur, int key) {
        int smaller = countBelow(cur, key, false);
        return smaller == 0 ? Integer.MIN_VALUE : select(cur, smaller).key;
    }

    /** Smallest stored value strictly greater than {@code key}, or Integer.MAX_VALUE if none. */
    public int upper(int key) {
        return upper(root, key);
    }

    /**
     * Smallest value stored in the subtree {@code cur} that is strictly greater than {@code key},
     * or Integer.MAX_VALUE if there is none. Runs in O(height) regardless of lazy deletions.
     */
    public int upper(Node cur, int key) {
        int atMost = countBelow(cur, key, true);
        return atMost == size(cur) ? Integer.MAX_VALUE : select(cur, atMost + 1).key;
    }

    /** Counts stored values in subtree {@code n} that are {@code < key} ({@code <= key} if inclusive). */
    private int countBelow(Node n, int key, boolean inclusive) {
        int count = 0;
        while (n != null) {
            boolean goRight = inclusive ? key >= n.key : key > n.key;
            if (goRight) {
                // n's left subtree and n's own copies are all below the bound.
                count += size(n.lc) + n.cnt;
                n = n.rc;
            } else {
                n = n.lc;
            }
        }
        return count;
    }

    /** Node holding the k-th smallest stored value (1-based) in subtree {@code n}, or null if out of range. */
    private Node select(Node n, int k) {
        while (n != null) {
            int leftSize = size(n.lc);
            if (k <= leftSize) {
                n = n.lc;
            } else if (k - leftSize > n.cnt) {
                // Also skips lazily deleted nodes (cnt == 0).
                k -= leftSize + n.cnt;
                n = n.rc;
            } else {
                return n;
            }
        }
        return null;
    }

    /** Debug helper: prints each distinct live key once, in ascending order. */
    public void dfs(Node n) {
        if (n == null) return;
        dfs(n.lc);
        if (n.cnt > 0) System.out.print(n.key + " ");
        dfs(n.rc);
    }

    /** Debug helper: prints live keys with their subtree sizes in pre-order. */
    public void dfs2(Node n) {
        if (n == null) return;
        if (n.cnt > 0) System.out.print(n.key + "[size" + n.size + "] ");
        dfs2(n.lc);
        dfs2(n.rc);
    }

}

public class Main {

    /** Fast whitespace-delimited token reader; Luogu's time limit is tight, so input must be buffered. */
    static class InputReader {
        public BufferedReader reader;
        public StringTokenizer tokenizer;

        public InputReader(InputStream stream) {
            reader = new BufferedReader(new InputStreamReader(stream), 32768);
            tokenizer = null;
        }

        /**
         * Returns the next whitespace-delimited token.
         *
         * @throws IllegalStateException if the input ends before another token is available
         * @throws RuntimeException wrapping any IOException raised by the underlying reader
         */
        public String next() {
            while (tokenizer == null || !tokenizer.hasMoreTokens()) {
                String line;
                try {
                    line = reader.readLine();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
                // readLine() returns null at EOF; new StringTokenizer(null) would throw a bare NPE.
                if (line == null) {
                    throw new IllegalStateException("unexpected end of input");
                }
                tokenizer = new StringTokenizer(line);
            }
            return tokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

    }

    /**
     * Reads {@code n} followed by {@code n} lines "opt x" and applies each operation to the tree:
     * 1 insert, 2 delete one copy, 3 rank, 4 k-th smallest, 5 predecessor, any other value
     * successor. Answers are written through one buffered writer flushed at the end, because
     * {@code System.out.println} may flush on every call, which is costly under the time limit.
     */
    public static void main(String[] args) {
        RedBlackTree rbt = new RedBlackTree();
        InputReader sc = new InputReader(System.in);
        PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        try {
            int n = sc.nextInt();
            for (int i = 0; i < n; i++) {
                int op = sc.nextInt();
                int key = sc.nextInt();
                switch (op) {
                    case 1:
                        rbt.add(key);
                        break;
                    case 2:
                        rbt.remove(key);
                        break;
                    case 3:
                        out.println(rbt.getRank(key));
                        break;
                    case 4:
                        out.println(rbt.getKth(key));
                        break;
                    case 5:
                        out.println(rbt.lower(key));
                        break;
                    default:
                        out.println(rbt.upper(key));
                        break;
                }
            }
        } finally {
            // Flush (not close) so System.out stays usable and partial output survives an exception.
            out.flush();
        }
    }
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注