我有两个数据框:a (~600M 行)和b (~2M 行) 。当在相应列上使用 1 个相等条件和2 个不等条件时,将 b 连接到 a 的最佳方法是什么?
- a_1=b_1
- a_2 >= b_2
- a_3 >= b_3
我目前探索了以下路径:
- 極色:
- join_asof():仅允许 1 个不等式条件
- join_where() 与 filter():即使容差窗口较小,标准 Polars 安装在连接期间也会用尽行数(4.3B 行限制),并且 polars-u64-idx 安装会耗尽内存(512GB)
- DuckDB:ASOF LEFT JOIN:也只允许 1 个不平等条件
- Numba:由于上述方法不起作用,我尝试创建自己的 join_asof() 函数 - 请参阅下面的代码。它工作正常,但随着 a 的长度增加,它变得非常慢。我尝试了各种不同的 for/while 循环和过滤配置,所有结果都相似。
现在我有点想不出主意了...有什么更有效的方法来实现这一点?
谢谢
import numba as nb
import numpy as np
import polars as pl
import time
@nb.njit(nb.int32[:](nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:]), parallel=True)
def join_multi_ineq(a_1, a_2, a_3, b_1, b_2, b_3, b_4):
output = np.zeros(len(a_1), dtype=np.int32)
for i in nb.prange(len(a_1)):
for j in range(len(b_1) - 1, -1, -1):
if a_1[i] == b_1[j]:
if a_2[i] >= b_2[j]:
if a_3[i] >= b_3[j]:
output[i] = b_4[j]
break
return output
length_a = 5_000_000
length_b = 2_000_000
start_time = time.time()
output = join_multi_ineq(a_1=np.random.randint(1, 1_000, length_a, dtype=np.int32),
a_2=np.random.randint(1, 1_000, length_a, dtype=np.int32),
a_3=np.random.randint(1, 1_000, length_a, dtype=np.int32),
b_1=np.random.randint(1, 1_000, length_b, dtype=np.int32),
b_2=np.random.randint(1, 1_000, length_b, dtype=np.int32),
b_3=np.random.randint(1, 1_000, length_b, dtype=np.int32),
b_4=np.random.randint(1, 1_000, length_b, dtype=np.int32))
print(f"Duration: {(time.time() - start_time):.2f} seconds")
您可以使用 DuckDB(Postgresql)
distinct on
子句:您也可以尝试使用,
pl.DataFrame.join_where()
但采用懒惰模式。我假设您的“a”数据框具有唯一键,在此示例中为 -a1,a2,a3
。pl.DataFrame.lazy()
将 DataFrame 视为 LazyFrame。pl.LazyFrame.join_where()
将惰性框架连接在一起。pl.LazyFrame.sort()
对结果进行排序。pl.LazyFrame.drop()
删除b2,b3
列。pl.LazyFrame.unique()
每行只留一行a1,a2,a3
。pl.LazyFrame.collect()
。如果这些都不起作用,您可以尝试用将其中一个帧分成 N 个块
pl.DataFrame.partition_by()
,分别处理块,然后使用pl.concat()
将它们连接回来。在这里使用 Numba 是个好主意,因为该操作特别昂贵。话虽如此,
O(n²)
但算法的复杂性却很难做得更好(而不会使代码变得更加复杂)。此外,数组b_1
可能不适合 L3 缓存,它被完全读取了 5_000_000 次,这使得代码相当受内存限制。我们可以通过构建索引来大大加快代码速度,这样就不必遍历整个数组
b_1
,而只需遍历值a_1[i] == b_1[j]
。这不足以改善复杂性,因为很多j
值都满足这个条件。我们可以通过为索引的所有节点构建一种树来改善(平均)复杂性,但在实践中,这会使代码变得更加复杂,构建树的时间会很长,以至于实际上不值得这样做。事实上,一个基本的索引足以大大减少提供的随机数据集(具有均匀分布的数字)上的执行时间。以下是生成的代码:请注意,平均而言,只
k
测试了 32 个值(这足够合理,不需要构建更高效/更复杂的数据结构)。还请注意,结果与简单实现提供的结果完全相同。基准
以下是我的 i5-9600KF CPU(6 核)上的结果:
因此,该实现比初始代码快约30 倍。