본문 바로가기
JAVA/Java Study

[Java] 이진 탐색 트리 - 재귀 형태 구현

by ♡˖GYURI˖♡ 2024. 2. 5.

이진 탐색 트리에 대해 공부하다가, 재귀 형태 구현이 잘 이해되지 않아 정리해본다.

 

import java.util.LinkedList;
import java.util.Queue;

class BinarySearchTree {
    Node head;

    BinarySearchTree(int key) {
        this.head = new Node(key, null, null);
    }

    public Node addNodeRecursive(Node cur, int key) {
        if (cur == null) {
            return new Node(key, null, null);
        }

        if (key < cur.key) {
            cur.left = addNodeRecursive(cur.left, key);
        } else {
            cur.right = addNodeRecursive(cur.right, key);
        }

        return cur;
    }

    public Node removeNodeRecursive(Node cur, int key) {
        if (cur == null) {
            return null;
        }

        if (key < cur.key) {
            cur.left = removeNodeRecursive(cur.left, key);
        } else if (key > cur.key) {
            cur.right = removeNodeRecursive(cur.right, key);
        } else {
            if (cur.left == null) {
                return cur.right;
            } else if (cur.right == null) {
                return cur.left;
            } else {
                Node predecessor = cur;
                Node successor = cur.left;

                while (successor.right != null) {
                    predecessor = successor;
                    successor = successor.right;
                }

                predecessor.right = successor.left;
                cur.key = successor.key;
            }
        }
        return cur;
    }

    public void levelOrder(Node node) {
        Queue<Node> queue = new LinkedList();
        queue.add(node);
        while (!queue.isEmpty()) {
            Node cur = queue.poll();

            System.out.print(cur.key + " ");
            if (cur.left != null) {
                queue.offer(cur.left);
            }

            if (cur.right != null) {
                queue.offer(cur.right);
            }
        }
        System.out.println();
    }

}


public class Practice {
    public static void main(String[] args) {
        // Test code
        // 노드 삽입
        BinarySearchTree bst = new BinarySearchTree(20);
        bst.head = bst.addNodeRecursive(bst.head, 10);
        bst.head = bst.addNodeRecursive(bst.head, 30);
        bst.head = bst.addNodeRecursive(bst.head, 1);
        bst.head = bst.addNodeRecursive(bst.head, 15);
        bst.head = bst.addNodeRecursive(bst.head, 25);
        bst.head = bst.addNodeRecursive(bst.head, 13);
        bst.head = bst.addNodeRecursive(bst.head, 35);
        bst.head = bst.addNodeRecursive(bst.head, 27);
        bst.head = bst.addNodeRecursive(bst.head, 40);
        bst.levelOrder(bst.head);

        // 노드 삭제
        bst.head = bst.removeNodeRecursive(bst.head,40);
        bst.levelOrder(bst.head);
        bst.head = bst.removeNodeRecursive(bst.head, 25);
        bst.levelOrder(bst.head);
        bst.head = bst.removeNodeRecursive(bst.head, 20);
        bst.levelOrder(bst.head);
    }
}

 

 

먼저 노드를 추가하는 것부터 살펴보자.

public Node addNodeRecursive(Node cur, int key) {
    if (cur == null) {
        return new Node(key, null, null);
    }

    if (key < cur.key) {
        cur.left = addNodeRecursive(cur.left, key);
    } else {
    	cur.right = addNodeRecursive(cur.right, key);
    }

    return cur;
}
BinarySearchTree bst = new BinarySearchTree(20);
bst.head = bst.addNodeRecursive(bst.head, 10);
bst.head = bst.addNodeRecursive(bst.head, 30);
bst.head = bst.addNodeRecursive(bst.head, 1);
bst.head = bst.addNodeRecursive(bst.head, 15);
bst.head = bst.addNodeRecursive(bst.head, 25);
bst.head = bst.addNodeRecursive(bst.head, 13);
bst.head = bst.addNodeRecursive(bst.head, 35);
bst.head = bst.addNodeRecursive(bst.head, 27);
bst.head = bst.addNodeRecursive(bst.head, 40);

 

일단 bst를 생성하면서 넘긴 20이 루트 노드가 된다.

bst.head = bst.addNodeRecursive(bst.head, 10); 은 다음과 같이 작동한다.

  • 루트 노드인 20과, key 값 10을 넘긴다.
  • if (cur == null)로 루트 노드가 비어있는지 확인한다. → 비어있지 않으니 넘어간다.
  • if (key < cur.key)로 키 값이 루트 노드의 값보다 작은지 확인한다. → 10 < 20 이니 O
  • cur.left = addNodeRecursive(cur.left, key)로 재귀시킨다.
    • 루트 노드의 왼쪽은 현재 비어있으니 null이기에 addNodeRecursive(null, 10)이 된다.
    • if (cur == null) → O 이니 new Node(10, null, null)을 리턴한다.
  • cur.left = new Node(10, null, null) 이 된다.

 

여기까지만 해보면 아래와 같은 그림이 된다.

 

 

addNodeRecursive() 정리

  • 이진 탐색 트리가 비어있으면 → 바로 새 노드 만들어서 루트 노드로
  • 이진 탐색 트리가 비어있지 않으면 → key값과 노드의 값 비교
    • key < cur.key 이면 key값이 왼쪽에 할당되어야 하기 때문에 cur.left = addNodeRecursive(cur.left, key) 호출
      • cur == null일 때까지 반복 
    • key > cur.key 이면 key값이 오른쪽에 할당되어야 하기 때문에 cur.right= addNodeRecursive(cur.right, key) 호출
      • cur == null일 때까지 반복

 

 

addNodeRecursive()는 그나마 한 번에 이해가 됐는데, 문제는 removeNodeRecursive()였다.

 

public Node removeNodeRecursive(Node cur, int key) {
    if (cur == null) {
        return null;
    }

    if (key < cur.key) {
        cur.left = removeNodeRecursive(cur.left, key);
    } else if (key > cur.key) {
        cur.right = removeNodeRecursive(cur.right, key);
    } else {
        if (cur.left == null) {
            return cur.right;
        } else if (cur.right == null) {
            return cur.left;
        } else {
            Node predecessor = cur;
            Node successor = cur.left;

            while (successor.right != null) {
                predecessor = successor;
                successor = successor.right;
            }

            predecessor.right = successor.left;
            cur.key = successor.key;
        }
    }
    return cur;
}
bst.head = bst.removeNodeRecursive(bst.head,40);
bst.head = bst.removeNodeRecursive(bst.head, 25);
bst.head = bst.removeNodeRecursive(bst.head, 20);

 

 

각각은 다음과 같은 과정을 통해 진행된다.

 

키값이 40인 노드 지우기
키값이 25인 노드 지우기                                                                                 키값이 20인 노드 지우기

 

 

removeNodeRecursive() 정리

  • cur == null이면 바로 null을 리턴
  • if (key < cur.key) 이면 key값이 cur보다 왼쪽에 있을테니 removeNodeRecursive(cur.left, key) 호출
  • else if (key > cur.key) 이면 key값이 cur보다 오른쪽에 있을테니 removeNodeRecursive(cur.right, key) 호출
  • else 는 둘 다 아니라는 것이니 서로 같을 때!
    • if (cur.left == null) 이면 return cur.right;
      • 찾은 노드가 왼쪽 자식 노드가 없고 오른쪽만 있을 때, 오른쪽 자식 노드를 해당 노드의 자리로 이동
    • else if (cur.right == null) 이면 return cur.left;
      • 찾은 노드가 오른쪽 자식 노드가 없고 왼쪽만 있을 때, 왼쪽 자식 노드를 해당 노드의 자리로 이동
    • else 는 둘 다 아닐 때
      • predecessor = cur;
      • successor = cur.left;
      • successor.right != null일 때까지 while문 반복
        • predecessor = successor; 
        • successor = successor.right;
        • 찾는 노드의 왼쪽 서브 트리 중 가장 큰 수를 찾는 과정
      • predecessor.right = successor.left;
      • 왼쪽 서브 트리 중 가장 큰 수의 왼쪽 자식 노드를 왼쪽 서브 트리 중 가장 큰 수의 오른쪽 자식 노드 자리(왼쪽 서브 트리 중 가장 큰 수의 자리)로 이동
      • 찾은 노드의 자리에 왼쪽 서브 트리 중 가장 큰 수를 넣음

 

 

정리는 했는데 다시 구현해보라고 하면 손도 못댈 것 같은 느낌...🥹

그래도 이해는 했다...!