- 2022.02.21 포스팅 기준 solved.ac 문제 등급 : 플래티넘 5
- 단계별로 풀어보기 -> 세그먼트 트리
- 알고리즘 분류
- 자료 구조
- 정렬
- 세그먼트 트리
- 분할 정복
<링크>
https://www.acmicpc.net/problem/1517
시간제한 | 메모리제한 |
1 초 | 512 MB |
<문제>
N개의 수로 이루어진 수열 A[1], A[2], …, A[N]이 있다. 이 수열에 대해서 버블 소트를 수행할 때, Swap이 총 몇 번 발생하는지 알아내는 프로그램을 작성하시오.
버블 소트는 서로 인접해 있는 두 수를 바꿔가며 정렬하는 방법이다. 예를 들어 수열이 3 2 1이었다고 하자. 이 경우에는 인접해 있는 3, 2가 바뀌어야 하므로 2 3 1 이 된다. 다음으로는 3, 1이 바뀌어야 하므로 2 1 3 이 된다. 다음에는 2, 1이 바뀌어야 하므로 1 2 3 이 된다. 그러면 더 이상 바꿔야 할 경우가 없으므로 정렬이 완료된다.
<입력>
첫째 줄에 N(1 ≤ N ≤ 500,000)이 주어진다. 다음 줄에는 N개의 정수로 A[1], A[2], …, A[N]이 주어진다. 각각의 A[i]는 0 ≤ |A[i]| ≤ 1,000,000,000의 범위에 들어있다.
<출력>
첫째 줄에 Swap 횟수를 출력한다
<예제 입력 1>
3
3 2 1
<예제 출력 1>
3
<풀이 접근>
이 문제의 풀이를 작성할 때, 정직하게 버블소트의 횟수를 세면서 세그먼트 트리를 갱신하면 시간 초과가 뜹니다. 입력의 개수가 50만이면서, 주어지는 정수의 최대 절댓값도 10억에 달하기 때문입니다. 따라서 이 문제는 앞서 다루었던 세그먼트 트리 문제들과 달리, 세그먼트 트리의 리프 노드가 입력 범위에 있는 전체 정수를 대변하지 않습니다.
이 문제에서 세그먼트 트리의 리프 노드들은 크기에 상관 없이, 입력받은 수열을 순서대로 저장합니다.
또한 이전의 문제들과 달리, 이번 문제에서는 세그먼트 트리의 각 노드의 값을 변경할 필요가 없습니다. 트리의 각 좌변 서브 트리와 우변 서브 트리를 비교하면, 버블 소트가 발생하는 횟수를 알 수 있기 때문에 트리를 작성하는 단계만 거치면 끝입니다. 이는 곧 병합 정렬(Merge sort)의 방식과 같다고 이해하셔도 될 듯합니다. 알고리즘 분류에 [분할 정복]이 있는 것이 이 때문인지도 모르겠습니다.
아래는 병합정렬을 위한 함수의 구현입니다.
def merge(a:list, b:list):
temp = []
count = 0
a_size = len(a)
i, j = 0, 0
while i < len(a) and j < len(b):
if a[i] <= b[j]:
temp.append(a[i])
i += 1
if a_size > 0:
a_size -= 1
else:
temp.append(b[j])
count += a_size
j += 1
while i < len(a):
temp.append(a[i])
i += 1
while j < len(b):
temp.append(b[j])
j += 1
return count, temp
우선, swap횟수를 저장하기 위한 count를 선언하고, 좌변 서브 트리 a[ ]의 크기를 a_size로 저장해둡니다.
정렬된 수를 저장하기 위해 빈 리스트 temp[ ]를 선언합니다.
a[ ]의 원소 각 a[i]마다 우변 서브 트리의 수열 b[ ]의 원소 b[j]와 비교합니다.
a[i]<=b[j] 이면 temp에 a[i]를 삽입하고, 이 수는 swap할 필요가 없으므로 a_size를 1 줄여준 후, a[ ]의 인덱스 번호 i를 1 증가시킵니다.
만약 a[i]>b[j] 이면 temp에 b[j]를 삽입하고, 인덱스 번호가 i 보다 크거나 같은 a[ ]의 원소들에는 b[j]와 정렬이 필요하므로 a_size를 count에 더해준 후 b[ ]의 인덱스 번호 j를 증가시킵니다.
어느 한 서브트리의 모든 원소에 대해 비교가 끝나면, 다른 서브 트리의 남은 원소들을 temp에 마저 삽입해줍니다.
정렬된 temp 리스트와 swap횟수를 저장한 count를 반환하면 부모 노드 이하 서브 트리의 정렬이 완료됩니다.
풀이의 자세한 구현은 아래의 코드를 참고하세요.
import sys
import math
input = sys.stdin.readline
n = int(input())
nums = [0] + [*map(int, input().split())]
M = max(nums)
seg_tree = [0 for _ in range((1 << math.ceil(math.log2(n + 1))+1) + 1)]
def merge(a, b):
temp = []
count = 0
a_size = len(a)
i, j = 0, 0
while i < len(a) and j < len(b):
if a[i] <= b[j]:
temp.append(a[i])
i += 1
if a_size > 0:
a_size -= 1
else:
temp.append(b[j])
count += a_size
j += 1
while i < len(a):
temp.append(a[i])
i += 1
while j < len(b):
temp.append(b[j])
j += 1
return count, temp
def make_tree(tree, start, end, v):
if start == end:
return 0, [nums[start]]
else:
mid = (start + end) // 2
left = make_tree(tree, start, mid, v*2)
right = make_tree(tree, mid+1, end, v*2+1)
count, merged = merge(left[1], right[1])
return left[0] + right[0] + count, merged
print(make_tree(seg_tree, 1, n, 1)[0])
위 풀이를 PyPy3로 제출한 결과는 아래와 같습니다. 113명중 30등입니다.
위 풀이를 Python3로 제출한 결과는 아래와 같습니다. 258명중 64등입니다.
이전에도 말씀 드렸다시피, Python언어 그룹에서 짧은 수행 시간을 중요하게 생각하신다면 PyPy3로 제출하시고, 메모리를 효율적으로 사용하고 싶으시다면 Python3를 사용하시는 걸 추천드립니다.
<마무리>
세그먼트 트리 문제이면서 병합정렬을 이용하여 풀 수 있는, 세그먼트 트리 자료구조와 분할 정복 알고리즘의 긴밀한 관계를 볼 수 있는 유익한 문제였던 것 같습니다.
개인적으로 구간합을 구하는 문제부터 느꼈지만, 세그먼트 트리는 마치 트리에 메모이제이션을 더한 자료구조라는 느낌이 듭니다.
'프로그래밍 > 백준 온라인 저지' 카테고리의 다른 글
[Python3] 5719 거의 최단 경로 - 플래티넘 5 (0) | 2022.04.25 |
---|---|
[Python3] 9345 디지털 비디오 디스크(DVDs) - 플래티넘 3 (0) | 2022.02.22 |
[Python3] 2357 최솟값과 최댓값 - 골드 1 (0) | 2022.02.17 |
[Python3] 11505 구간 곱 구하기 - 골드 1 (0) | 2022.02.11 |
[Python3] 2042 구간 합 구하기 - 골드 1 (0) | 2022.02.10 |