我有如下数据:
lf = pl.LazyFrame(
{
"points": [
[
[1.0, 2.0],
],
[
[3.0, 4.0],
[5.0, 6.0],
],
[
[7.0, 8.0],
[9.0, 10.0],
[11.0, 12.0],
],
],
"other": ["foo", "bar", "baz"],
},
schema={
"points": pl.List(pl.Array(pl.Float32, 2)),
"other": pl.String,
},
)
我想让所有列表都具有相同数量的元素。如果当前列表包含的元素多于我需要的元素,则应将其截断。如果列表包含的元素少于我需要的元素,则应按顺序重复执行,直到包含足够的元素。
我设法让它工作了,但我觉得我有点太过繁琐了。有没有更简洁的方法来实现这一点?也许可以用gather
?
target_length = 3
result = (
lf.with_columns(
needed=pl.lit(target_length).truediv(pl.col("points").list.len()).ceil()
)
.with_columns(
pl.col("points")
.repeat_by("needed")
.list.eval(pl.element().explode())
.list.head(target_length)
)
.drop("needed")
)
编辑
上述方法适用于玩具示例,但当我尝试在真实数据集中使用它时,它会失败:
pyo3_runtime.PanicException: Polars' maximum length reached. Consider installing 'polars-u64-idx'.
我还没能为此制作一个 MRE,但我的数据有 400 万行,每行的“点”列表有 1 到 8000 个元素(我试图填充/截断到 800 个元素)。这些元素看起来都很小,我不明白如何u32
达到最大长度。
我很感激我可以尝试的任何替代方法。
我最接近的(不惊慌)是:
但这不会按顺序填充重复列表。它只会填充重复的最后一个元素。
target_length = 3
result = (
lf.with_columns(
pl.col("points")
.list.gather(
pl.int_range(target_length),
null_on_oob=True,
)
.list.eval(pl.element().forward_fill())
)
.drop("needed")
)