Quick Sort 원리와 구현 (Python, Go)

Quick Sort는 분할 정복을 이용한 정렬 알고리즘으로 평균적으로 좋은 성능을 보인다. 

  • 평균 시간 복잡도: $O(nlog_2n)$
  • 정렬된 리스트에 대해: $O(n^2)$

정렬 원리 1

퀵 정렬의 아이디어는 생각보다 간단하다. 

  1. 기준을 정한다.
  2. 기준보다 작으면 기준점의 왼쪽으로, 기준보다 크면 기준점의 오른쪽으로 모은다.
  3. 기준의 왼쪽과 오른쪽의 배열을 쪼갤 수 없을 때까지 재귀적으로 정렬한다. 

이때 기준점을 pivot이라고 한다. 

 

배열의 첫번째 값을 pivot으로 설정한 예시:

1. 
[5 4 3 1 6 2 7]

2. pivot: 5
[4 3 1 2] + [5] + [6 7]

3. pivot: 4, 6
[3 1 2] + [4] + [5] + [6] + [7]

4. pivot: 3, (우측은 정렬 완료)
[1 2] + [3] + [4] + [5] + [6] + [7]

5. pivot: 1
[1] + [2] + [3] + [4] + [5] + [6] + [7]

-> [1 2 3 4 5 6 7]

 

정렬 원리 2

위와 같은 방식으로 좌우측에 배열을 생성해가며 정렬하면 구현하기 매우 간단하다. 하지만 문제는 메모리를 많이 사용한다는 것이다. 그래서 새로운 리스트를 생성하지 않고 pivot을 기준으로 나누는 방식 또한 사용한다. 

  1. 배열의 첫 값을 pivot으로 설정한다.
  2. 투 포인터를 이용해 값을 교체한다.
    1. 포인터1은 왼쪽에서부터 pivot보다 큰 값을 찾는다.
    2. 포인터2는 오른쪽에서부터 pivot보다 작은 값을 찾는다.
    3. 두 값의 위치를 바꾼다. 
  3. 만약 포인터가 엇갈렸다면 pivot과 포인터2 스왑
  4. 좌우측 배열에 대하여 재귀적으로 반복

예시: 

arr: [3, 2, 5, 1, 0, 7]

1. pivot 설정
pivot: 3 
[3] + [2, 5, 1, 0, 7]

2. 양쪽 끝에 포인터(p) 설정
p1: 2, p2: 7

2-1. pivot보다 큰 값이 나올 때까지 p1 이동
p1: 5 

2-2. pivot보다 작은 값이 나올 때까지 p2 이동
p2: 0

2-3. p1와 p2 스왑
p1: 5 , p2: 0
   [3] + [2, 5, 1, 0, 7]
-> [3] + [2, 0, 1, 5, 7]

2.
p1: 7, p2: 1

3. 포인터가 엇갈렸다면 pivot과 p2 스왑
   [3] + [2, 0, 1, 5, 7]
-> [1] + [2, 0, 3, 5, 7]
 = [1, 2, 0] + [3] + [5, 7]
 
4. 좌우 배열에 대하여 재귀적으로 반복

과정은 복잡해졌지만 새로운 공간을 사용하지 않고 초기에 설정한 배열의 메모리 내에서 값을 옮기며 정렬할 수 있게 되었다. 


구현 (Python, Go)

 

[ PyPy3 ] 정렬 원리 1

  • 백준: 148ms, 119036KB
def quick_sort(arr: list):
    if len(arr) <= 1:
        return arr

    pivot = arr.pop(0)
    left = [x for x in arr if x <= pivot]
    right = [x for x in arr if x > pivot]

    return quick_sort(left) + [pivot] + quick_sort(right)

if __name__ == "__main__":
    arr = [3, 5, 8, 6, 1, 2, 4, 7]
    arr = quick_sort(arr)
    print(arr)

 

[ PyPy3 ] 정렬 원리 2

  • 백준: 128ms, 11563KB
def quick_sort(arr, start=0, end=None):
  
    if end is None:
      end = len(arr) - 1
      
    if end <= start:
        return
    
    pivotIdx = start 
    pivot = arr[start]
    low  = start + 1
    high = end 

    while low <= high:
        while low <= end and arr[low] <= pivot:
            low += 1
        while start < high and pivot <= arr[high]:
            high -= 1

        if low > high:
            arr[high], arr[pivotIdx] = arr[pivotIdx], arr[high]
        else: 
            arr[low], arr[high] = arr[high], arr[low]

    quick_sort(arr, start, high-1)
    quick_sort(arr, high+1, end)

if __name__ == "__main__":
    arr = [3, 5, 8, 6, 1, 2, 4, 7]
    quick_sort(arr)
    print(arr)

2번 방법이 메모리도 적게 사용하고 실행 시간도 단축된 것을 볼 수 있다. 하지만 1번 방식이 읽기 훨씬 편하다. 

 

[ Go ] 정렬 원리 2

  • 백준: 4ms, 1064KB
package main

import "fmt"

func main() {
	arr := []int{3, 5, 8, 6, 1, 2, 4, 7}
	arr = quickSort(arr)
	fmt.Print(arr)
}

func quickSort(arr []int) []int {
	sorted := SortedArr{arr}
	lenArr := len(arr)
	sorted.Sort(0, lenArr-1)
	return sorted.GetSorted()
}

type SortedArr struct {
	arr []int
}

func (s *SortedArr) Sort(start, end int) {
	if end <= start {
		return
	}

	pivotIdx := start
	pivot := s.arr[start]
	low := start + 1
	high := end

	for low <= high {
		for low <= high && s.arr[low] <= pivot {
			low += 1
		}
		for start < high && pivot <= s.arr[high] {
			high -= 1
		}

		if high < low {
			s.arr[high], s.arr[pivotIdx] = s.arr[pivotIdx], s.arr[high]
		} else {
			s.arr[low], s.arr[high] = s.arr[high], s.arr[low]
		}
	}
	s.Sort(start, high-1)
	s.Sort(high+1, end)
}

func (s *SortedArr) GetSorted() []int {
	return s.arr
}

Python과 코드의 원리는 동일하지만 Go로 작성했을 때 압도적인 성능을 볼 수 있다.