Pytorch 中的 topk 运算“返回给定输入张量沿给定维度的 k 个最大元素”。(从这里开始)。但它有相反的方法返回最小的 k 个元素吗?
前 k
x = torch.arange(1., 6.)
print(x)
>>> tensor([ 1., 2., 3., 4., 5.])
torch.topk(x, 3)
>>> torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
min k --> 这存在吗?
x = torch.arange(1., 6.)
print(x)
>>> tensor([ 1., 2., 3., 4., 5.])
torch.mink(x, 3) # doesn't run
>>> torch.return_types.mink(values=tensor([1., 2., 3.]), indices=tensor([0, 1, 2]))
我知道解决这个问题的方法是将张量乘以-1
:
torch.topk(-x, 3)
但想知道这个mink
操作是否已经存在。
检查文档页面https://pytorch.org/docs/stable/ generated/torch.topk.html
有一个关键字参数来获取最小值(和索引)
所以
mink
就是这样functools.partial(torch.topk, largest=False)