当前位置: 首页 > >

Quicksort变体,Quickselect寻找数组中位数,使用median-of-median of five策略挑选pivot,复杂度O(N)

发布时间:

算法说明:
QuickSort的递归式为T(N) = 2T(N/2) + O(N),即对原数组进行遍历分类(复杂度O(N))并对小于等于pivot和大于等于pivot的两个组别继续排序,复杂度解得为T(N) = O(NlogN);


QuickSelect脱胎于QuickSort,由于只需要获取kth smallest/largest的一个元素(包括中位数,定义为第(size+1)/2个最小元素),因此在遍历分类后只需要取一侧继续分类处理即可,因此递归式为T(N) = T(N/2) + O(N),即可达到T(N) = O(N)的复杂度;


但是问题在于如何保证T(N/2),我们必须保证每次分类后的子数组元素数目为父数组元素的某一个比例,这个比例可以是理想的1/2,当然也可以是2/3, 4/5等等,只要能保证一个最大的比例值,就可以获得O(N)的复杂度(当然系数不同);
可以这么理解:若能保证一个最大的比例就可以保证每次子数组的size必然成比例的缩小,而不会产生pivot为较大值或较小值导致一次只分出去几个元素的低效情况;


这也就是median-of-median of five的策略的用处,该策略的目的是获取一个满足上面要求的pivot,该策略示意图如下:

即我们将数组分组,5个元素为一组,然后求取每组的中位数,并对这ceil(N/5)个中位数再次求取中位数作为pivot;


原理上获取pivot = v后(假设一共45个元素分为9组),则可以判断的是大于v的数字有:
4(9个中位数中的4个)+2(v所在5元素组的2个)+8(9个中位数中4个比v大的L元素所在5元素组的一共4*2 = 8个) = 14个;
同理可知比v小的数字有14个;
因此以v = pivot分组必然可获得14~31的分组,也就是70%的比例;


此时QuickSelect可以保证O(N),但是寻找pivot的过程也需要保证O(N):
对于5元素数组寻找中位数可以直接选择排序,因为只有5个元素,所以经过常数次比较必然可以得到中位数,因而找出所有5元素子数组的复杂度为O(N);
对于中位数的数组,可以递归循环该过程,因此总体复杂度为O(N+N/5+N/25+…) = O(N);


下面是实现:
首先是相关函数:
包括插入排序,返回pivot的median_m5函数,quickselect函数,和产生随机序列的函数;
这里一方面需要注意在原数组中定位pivot,因为median_m5函数较难保持返回值的位置信息,所以采取了遍历的方式,复杂度O(N),不影响整体复杂度;
另外注意QuickSelect中需要在两端遍历时更新i,j(所以初期要将i,j移动到左右两端之外),因为如果在交换时更新会产生虽然i≥j,但是n[i]可能小于pivot的情况,如此顺序会有误;


#include
#include
using std::vector, std::swap;
void insertionsort(vector &n, int lc, int rc)
{
int i, j;
for (i = lc+1; i <= rc; ++i) {
int tmp = n[i];
for (j = i - 1; j >= lc; --j) {
if (tmp < n[j])
n[j+1] = n[j];
else
break;
}
n[j+1] = tmp;
}
}
int median_m5(vector &n, int lc, int rc)
{
if (rc-lc+1 <= 5)
{
insertionsort(n, lc, rc);
int idx = (lc+rc)/2;
return n[idx];
}
vector medians, tmp;
for (; rc-lc+1 >= 5; lc += 5)
medians.push_back(median_m5(n,lc,lc+4));
if (lc != rc)
medians.push_back(median_m5(n, lc, rc));
return median_m5(medians, 0, medians.size()-1);
}
int quickselect(vector &n, int lc, int rc, int k)
{
int pivot = median_m5(n, lc, rc);
for (int i = lc; i <= rc; ++i)
if (n[i] == pivot)
{
swap(n[i], n[rc]);
break;
}
int i = lc-1, j = rc;
while (i < j) {
while (++i <= rc && n[i] < pivot)
;
while (--j >= lc && n[j] > pivot)
;
if (i < j)
swap(n[i], n[j]);
}
swap(n[i], n[rc]);
if (i == k-1)
return pivot;
else if (i < k-1)
return quickselect(n, i+1, rc, k);
else
return quickselect(n, lc, i-1, k);
}
void GenerateRandomNumber(vector &container, int n)
{
std::random_device rd;
std::default_random_engine rng(rd());
std::uniform_int_distribution dist(0, 50);
for (int i = 0; i < n; ++i)
container.push_back(dist(rng));
}

测试程序如下:
其中后面使用排序方式确认中位数;
注意QuickSelect接受数组,左右端index和k值(k表示第k个数字,从1开始);


int main() {
vector numbers;
GenerateRandomNumber(numbers, 12);
for (auto &num : numbers)
cout << num << " ";
cout << endl;
cout << quickselect(numbers, 0, numbers.size()-1,
(numbers.size()+1)/2) << endl;

sort(numbers.begin(), numbers.end(), less());
for (auto &num : numbers)
cout << num << " ";
cout << endl << numbers[(numbers.size()+1)/2 - 1] << endl;

return 0;
}



友情链接: