@huhl_2009 用 C++ 实现 Top-K 算子，要求能够通过阈值进行筛选，在 ARM 上时间复杂度尽可能低，最好能够并行优化。
```C++
#include <algorithm>
#include <cstddef>
#include <exception>
#include <functional>
#include <iostream>
#include <queue>
#include <system_error>
#include <thread>
#include <vector>
// Sort worker: orders the half-open index range [start, end) of `data` in
// place, ascending, using std::sort's default comparison (operator<).
// Elements outside that range are left untouched.
template<typename T>
void parallelSort(std::vector<T>& data, int start, int end) {
    const auto sliceBegin = data.begin() + start;
    const auto sliceEnd = data.begin() + end;
    std::sort(sliceBegin, sliceEnd);
}
// Merge worker: walks the half-open index ranges [start1, end1) and
// [start2, end2) of `data` side by side and pushes every element of both
// ranges into `result`.
//
// While both ranges still have elements, the larger of the two current heads
// is pushed first; on a tie the element from the second range goes first.
// Whatever remains in either range is pushed afterwards.
//
// This function does no locking of its own around `result`.
template<typename T>
void parallelMerge(std::vector<T>& data, int start1, int end1, int start2, int end2, std::priority_queue<T>& result) {
    int left = start1;
    int right = start2;

    // Interleave the two ranges until one of them runs out.
    while (left < end1 && right < end2) {
        const bool takeLeft = data[left] > data[right];
        int& cursor = takeLeft ? left : right;
        result.push(data[cursor]);
        ++cursor;
    }

    // Drain the leftovers; at most one of these loops does any work.
    for (; left < end1; ++left) {
        result.push(data[left]);
    }
    for (; right < end2; ++right) {
        result.push(data[right]);
    }
}
// Parallel top-k operator with threshold filtering.
//
// Returns at most `k` elements of `data` that satisfy `element >= threshold`,
// picking the largest ones and ordering them from largest to smallest.
// Duplicate values in `data` are kept as separate results.
//
// Parameters:
//   data      - input values. They are only read here; the non-const
//               reference is kept so existing callers keep compiling.
//   k         - maximum number of results. k <= 0 yields an empty vector.
//   threshold - inclusive lower bound an element must reach to be eligible.
//
// Approach: the input is cut into contiguous chunks, one per worker. Each
// worker filters its chunk against the threshold and collects survivors in a
// private buffer of at most 2*k candidates, cutting it back to the k largest
// with std::nth_element whenever it fills up. The per-chunk survivors are then
// concatenated, cut back to k once more, and sorted in descending order.
// No worker writes to state shared with another worker, so no lock is needed.
//
// Small inputs are handled entirely on the calling thread. If a worker thread
// cannot be started, its chunk is processed on the calling thread instead.
// An exception raised while processing a chunk is rethrown to the caller once
// all started threads have been joined.
//
// T must be copyable and support operator>= and operator>.
template<typename T>
std::vector<T> parallelTopK(std::vector<T>& data, int k, T threshold) {
    std::vector<T> topK;
    if (k <= 0 || data.empty())
        return topK;

    const std::size_t total = data.size();
    const std::size_t limit = static_cast<std::size_t>(k);
    const std::size_t bufferCap = limit * 2;
    const std::greater<T> descending{};

    // Cuts `pool` back to its `limit` largest elements (left unordered).
    const auto trim = [&](std::vector<T>& pool) {
        if (pool.size() <= limit)
            return;
        const auto cut = pool.begin() + static_cast<std::ptrdiff_t>(limit);
        std::nth_element(pool.begin(), cut, pool.end(), descending);
        pool.erase(cut, pool.end());
    };

    // Filters data[first, last) by the threshold and returns the (at most)
    // `limit` largest survivors. The buffer is local to the call so that
    // concurrent workers never touch neighbouring memory while scanning.
    const auto scanChunk = [&](std::size_t first, std::size_t last) {
        std::vector<T> kept;
        for (std::size_t i = first; i < last; ++i) {
            if (!(data[i] >= threshold))
                continue;
            kept.push_back(data[i]);
            if (kept.size() >= bufferCap)
                trim(kept);
        }
        trim(kept);
        return kept;
    };

    // hardware_concurrency() is allowed to return 0 when the value is not
    // computable, so clamp to at least one worker. Also avoid spawning
    // threads whose chunk would be too small to pay for the thread start.
    constexpr std::size_t kMinChunk = std::size_t{1} << 14;
    std::size_t workers = std::thread::hardware_concurrency();
    workers = std::max<std::size_t>(workers, 1);
    workers = std::min(workers, std::max<std::size_t>(total / kMinChunk, 1));

    std::vector<std::vector<T>> partial(workers);
    std::vector<std::exception_ptr> failure(workers);

    // Thread entry point: must not let an exception escape, otherwise the
    // program would terminate. Errors are parked and rethrown after joining.
    const auto runChunk = [&](std::size_t w, std::size_t first, std::size_t last) noexcept {
        try {
            partial[w] = scanChunk(first, last);
        } catch (...) {
            failure[w] = std::current_exception();
        }
    };

    const std::size_t base = total / workers;
    const std::size_t extra = total % workers;  // first `extra` chunks get one more element

    std::vector<std::thread> pool;
    pool.reserve(workers - 1);  // reserve first so emplace_back never reallocates
    std::size_t first = 0;
    for (std::size_t w = 0; w < workers; ++w) {
        const std::size_t last = first + base + (w < extra ? 1 : 0);
        if (w + 1 == workers) {
            // The calling thread takes the final chunk instead of idling.
            runChunk(w, first, last);
        } else {
            try {
                pool.emplace_back(runChunk, w, first, last);
            } catch (const std::system_error&) {
                // Could not start a thread: do this chunk right here.
                runChunk(w, first, last);
            }
        }
        first = last;
    }
    for (std::thread& worker : pool)
        worker.join();

    for (const std::exception_ptr& error : failure) {
        if (error)
            std::rethrow_exception(error);
    }

    // Combine the per-chunk survivors and make the final selection.
    topK = std::move(partial.front());
    for (std::size_t w = 1; w < workers; ++w)
        topK.insert(topK.end(), partial[w].begin(), partial[w].end());
    trim(topK);
    std::sort(topK.begin(), topK.end(), descending);  // largest first
    return topK;
}
// Demo entry point: runs parallelTopK on a fixed nine-element sample with
// k = 4 and threshold = 5, then prints the returned values on a single line
// separated by spaces.
int main() {
    std::vector<int> samples{9, 4, 6, 2, 8, 5, 1, 7, 3};
    const int limit = 4;
    const int cutoff = 5;
    const std::vector<int> picked = parallelTopK(samples, limit, cutoff);

    std::cout << "Top " << limit << " elements greater than or equal to " << cutoff << ": ";
    for (const int value : picked) {
        std::cout << value << " ";
    }
    return 0;
}
```