프로그래밍/백준 온라인 저지

[Python3] 12899 데이터구조 - 플래티넘 4

bright_P 2022. 2. 9. 15:38
  • 2022.02.09 포스팅 기준 solved.ac 문제 등급 : 플래티넘 4
  • 단계별로 풀어보기 -> 세그먼트 트리
  • 알고리즘 분류
    • 자료구조
    • 세그먼트 트리

 

<링크>
https://www.acmicpc.net/problem/12899

 

12899번: 데이터 구조

첫째 줄에 사전에 있는 쿼리의 수 N 이 주어집니다. (1 ≤ N ≤ 2,000,000) 둘째 줄부터 N개의 줄에 걸쳐 각 쿼리를 나타내는 2개의 정수 T X가 주어집니다. T가 1이라면 S에 추가할 X가 주어지는 것입니

www.acmicpc.net

 

시간제한 메모리제한
2 초 512 MB

 


<문제>
자연수를 저장하는 데이터베이스 S에 대해 다음의 쿼리를 처리합시다.
유형 1 : S에 자연수 X를 추가한다.
유형 2 : S에 포함된 숫자 중 X번째로 작은 수를 응답하고 그 수를 삭제한다.

<입력>
첫째 줄에 사전에 있는 쿼리의 수 N 이 주어집니다. (1 ≤ N ≤ 2,000,000)
둘째 줄부터 N개의 줄에 걸쳐 각 쿼리를 나타내는 2개의 정수 T X가 주어집니다.
T가 1이라면 S에 추가할 X가 주어지는 것입니다. (1 ≤ X ≤ 2,000,000)
T가 2라면 X는 S에서 삭제해야 할 몇 번째로 작은 수인지를 나타냅니다. S에 최소 X개의 원소가 있음이 보장됩니다.

<출력>
유형 2의 쿼리 개수만큼의 줄에 각 쿼리에 대한 답을 출력합니다.

 

<예제 입력 1>

5
1 11
1 29
1 89
2 2
2 2

<예제 출력 1>

29
89

 


<풀이 접근>

입력받는 정수의 범위가 1 이상 200만 이하로, 그 범위가 매우 큽니다.
따라서 세그먼트 트리의 크기는 (200만 보다 큰 2의 제곱수) * 2 이상이어야 합니다.
이 크기는 곧 2**ceil(log2(2000001))로 나타낼 수 있고, 어림잡아 계산한다면 200만 * 4로 나타내도 무방하다고 합니다.
(단, ceil은 소수의 올림을 뜻합니다)

더 효율적인 연산을 위해 비트 쉬프트 연산으로 표현하면 아래와 같습니다.

import math

seg_tree = [0 for _ in range((1 << math.ceil(math.log2(size + 1))+1) + 1)]

트리 인덱스 계산의 편의상, 0을 비워두고 1부터 사용하기 위해 range의 범위 끝에서 1을 더해줍니다.
이렇게 하면, 인덱스가 1부터 시작하는 트리에서 n인 노드의 자식 노드들을 각각 n*2n*2+1로 나타낼 수 있습니다.


세그먼트 트리의 리프 노드를 제외한 각 노드의 값은, 해당 노드가 루트 노드인 서브 트리에 들어있는 정수의 개수입니다.
세그먼트 트리의 각 리프노드는 순서대로 1부터 200만까지이고, t=1인 쿼리의 x에 해당하는 정수가 입력된 횟수를 저장합니다.
각 리프노드의 값은 t=2인 쿼리의 x에 대해서, 현재 세그먼트 트리에 저장된 수들 중 x번째로 작은 정수에 해당할 때 1씩 감소합니다.

t=1인 쿼리가 들어올 때는 insert 함수를, t=2인 쿼리가 들어올 때는 pop함수를 수행하도록 코드를 작성하겠습니다.

<풀이 1>

import sys
import math

# 최대 입력의 크기가 200만으로 매우 크므로, 속도가 느린 기본 input()을 대신해 sys.stdin.readline을 사용합니다.
input = sys.stdin.readline

n = int(input())

# 최대 입력 크기인 200만에 대응할 수 있도록 세그먼트 트리의 크기를 잡습니다.
size = 2000000  
seg_tree = [0 for _ in range((2 << int(math.ceil(math.log2(size + 1)))+1) + 1)]


# insert함수를 이용해 t=1인 쿼리가 주어질 때 마다, 
# v=1인 루트노드부터 정수 x에 해당하는 리프노드까지 세그먼트 트리를 갱신합니다.

def insert(tree, start, end, v, x):
    tree[v] += 1
    if start==end:
        return
    mid = (start + end) // 2
    if x <= mid:
        insert(tree, start, mid, v*2, x)
    else:
        insert(tree, mid+1, end, v*2+1, x)
    return


# pop함수를 이용해 t=2인 쿼리가 주어질 때 마다, 
# 세그먼트 트리에서 x번째로 작은 수를 찾아가며 각 노드의 값을 감소시키고, 
# 최종 노드에서 그 수를 반환합니다.

def pop(tree, start, end, v, x):
    if tree[v]<1:
        raise RuntimeError
    tree[v] -= 1
    if start == end:
        return start
    mid = (start + end) // 2
    if x <= tree[v*2]:
        return pop(tree, start, mid, v * 2, x)
    else:
        return pop(tree, mid + 1, end, v * 2 + 1, x-tree[v*2])


# 함수의 정의가 끝났으니, n개의 쿼리를 받고 수행합니다.
for T, X in [[*map(int, input().split())] for _ in range(n)]:
    if T == 1:
        insert(seg_tree, 1, size, 1, X)
    else:
        print(pop(seg_tree, 1, size, 1, X))

위 풀이는 재귀함수를 이용해 구현한 것입니다.
위와 같은 풀이로는 아래의 결과와 같이, 꽤 느린 수행 속도를 보게 됩니다.
PyPy3로 제출해야 통과되고, Python3로는 시간 초과가 나옵니다.

<풀이 2>

import sys
import math

input = sys.stdin.readline

n = int(input())

size = 2000000
seg_tree = [0 for _ in range((1 << math.ceil(math.log2(size + 1))+1) + 1)]


# <풀이 1>에서와 달리, 같은 함수를 while loop를 이용해 작성합니다.

def insert(tree, start, end, v, x):
    while start!=end:
        tree[v] += 1
        mid = (start + end) // 2
        if x <= mid:
            end = mid
            v = v*2
        else:
            start = mid+1
            v = v*2+1
    tree[v] += 1
    return


def pop(tree, start, end, v, x):
    while start != end:
        tree[v] -= 1
        mid = (start + end) // 2
        if x <= tree[v*2]:
            end = mid
            v = v * 2
        else:
            start = mid + 1
            x -= tree[v * 2]
            v = v * 2 + 1
    tree[v] -= 1
    return start


for T, X in [[*map(int, input().split())] for _ in range(n)]:
    if T == 1:
        insert(seg_tree, 1, size, 1, X)
    else:
        print(pop(seg_tree, 1, size, 1, X))​

위 풀이는 재귀함수를 사용하지 않고 while loop를 사용한 것입니다.
아래의 결과와 같이, 절반 정도로 수행 시간을 줄일 수 있습니다.
하지만 <풀이 1>과 마찬가지로, Python3로 제출하면 시간 초과가 나옵니다.

 

 


<마무리>
Python 언어그룹의 정답자들을 보면 짧은 코드가 600바이트 내외, 빠른 코드가 2000ms 내외로 나옵니다.
위 풀이에 더 효율적인 코드 스타일을 찾아 고칠 수 있을 것으로 보입니다.

이 문제 밖에도 Python3에서는 재귀함수를 이용하는 것 보다 while문을 이용하는것이 훨씬 빠를 때가 있습니다.
저는 BFS를 구현할때에도 재귀 대신 while문을 사용합니다.
적절하게 사용하면 좋은 테크닉일 듯 합니다.