Binary Search 的正確實作方法

Binary Search

對於有資訊背景的人一定都覺得 Binary Search 是基本概念。但越是簡單的概念,背後隱藏了越多你沒思考過的細節。

❝ Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky ❞

- Donald Knuth

Knuth 於 1998 年出版的 The Art of Computer Programming – Searching an ordered table 一節中就曾提及:雖然 binary search 是一個相對直覺的概念,但細節卻出乎意料的麻煩。

看看一般教科書的標準 binary search pseudocode:

BinarySearch(A[0..N-1], value) {
	low = 0
	high = N - 1
	while (low <= high) {
		// invariants: value > A[i] for all i < low 
		//             value < A[i] for all i > high
		mid = (low + high) / 2
		if (A[mid] > value)
			high = mid - 1
		else if (A[mid] < value)
			low = mid + 1
		else
			return mid
	}
	return not_found // value would be inserted at index "low"
}

基本就是很直覺的每次切半基本程式邏輯。但即使是 Java 自帶的 Binary Search 實作其實都曾經錯了九年之久。

Binary Search 一直是 LeetCode 題型變化的基礎之一。之前 刷 LeetCode 時,刷到一題 “Easy” 的 binary search 相關題目,依照教科書的 pseudocode 實作,理應一次提交完成了事,不料竟 TLE。

題目是 278. First Bad Version:給定 N 個軟體版本:1, 2, ….. N。找出第一個開始 broken 的版本。

不加思索立刻實作如下:

public int firstBadVersion(int n) {
    int left = 1;
    int right = n;

    while (left < right) {            
        int mid = (left + right) / 2;

        if (isBadVersion(mid)) {
            right = mid;
        } else {
            left = mid + 1;
        }
    }

    return left;        
}

如果一眼就能看出 overflow 的問題的話,就可以跳過後面不用看了。這個實作的問題就在 int mid = (left + right) / 2 這一行。這裡比較能直接找出問題的主因在於如果設計的 test cases 不夠完整,沒有考慮到 edge cases 的話就很容易忽略這個 overflow 的問題。Java int 本身是 32 bit 所以其最大值是 2^31 – 1 = 2,147,483,647,也就是說當 left + right 超過 2,147,483,647 時,這時 mid 就變成無法預期的負數,就可能產生無法跳出的迴圈。

只要稍微修改一下 int mid = left + (right – left) / 2 就搞定了。

至於前面提到 Java 錯了九年的 Binary Search 其實也是一樣的 overflow error。

public static int binarySearch(int[] a, int key) {
	int low = 0;
	int high = a.length - 1;

	while (low <= high) {
		int mid = (low + high) / 2;
		int midVal = a[mid];

		if (midVal < key)
			low = mid + 1
		else if (midVal > key)
			high = mid - 1;
		else
			return mid; // key found
	}
	return -(low + 1);  // key not found.
}

Java 的修正方法則是把取 mid 的方法改成:

int mid = (low + high) >>> 1;

這個實作使用的是 Logical Shift operator 直接忽略 sign bit,看起來是比較好一點的解法。

乍看之下可能覺得犯這樣的錯誤很不可思議,但其實很早就有過一篇研究發現 20 本教科書中僅 5 本是正確實作了 binary search

Binary Search 的 trick 也衍生了其變形的題目,比如說 LeetCode #34 Find First and Last Position of Element in Sorted Array 也是滿考驗臨場對於 binary search 的 indexing sense。

題目:給定一依照遞增排序的數列,找出 target value 的起始與結束位置。

比如輸入數列是 {5, 7, 7, 8, 8, 10}, target 是 8

那麼起始 index = 3,結束 index = 4

找起始 index 的方法是僅針對 mid value 大於 target value 的 case 遞增 index

int left = 0;
int right = nums.length;

while (left < right) {
    int mid = left + (right - left) / 2;

    if (target < nums[mid] || target == nums[mid]) {
        right = mid;
    } else {
        left = mid + 1;
    }
}
// left 就是起始 index,但還要額外考慮不存在的 case 

找結束 index 的方法則是對 mid value 大於等於 target value 的 case 遞增 index

int left = 0;
int right = nums.length;

while (left < right) {
    int mid = left + (right - left) / 2;

    if (target < nums[mid]) {
        right = mid;
    } else {
        left = mid + 1;
    }
}
// left - 1 就是結束 index 

這邊 tricky 的地方還有 right = nums.length,標準算法是 right = nums.length – 1。但因為這裡僅對單一邊遞增 index 所以 while loop 必須對應為拿掉 equal case (left <= right 改為 left < right) 否則會產生 infinite loop,因應這個 loop 條件要能考慮到最後一個 element 的話就必須將 right 改為 nums.length 而不是 nums.length – 1。完整 solution 放在 github 上

簡單總結一下,本質上對於 data type 表示的掌握以及 edge cases 的思考這樣的基本功還是必須時常維持著,才能夠在臨場完整解決基本算法變型的問題。





Leave a Reply

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *