最近,PyTorch 引入了嵌套张量。但是,如果我创建一个嵌套张量,例如,
import torch
a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
然后查看它的类类型,它显示:
type(nt)
torch.Tensor
即,类类型只是常规的 PyTorch Tensor
。所以type(nt) == torch.Tensor
和isinstance(nt, torch.Tensor)
都会返回True
。
所以,我的问题是,有没有办法区分常规张量和嵌套张量?
我能想到的一种方法是,size
嵌套张量的方法(当前)与常规张量的工作方式不同,因为它需要一个参数,否则它会引发一个RuntimeError
. 因此,解决方案可能是:
def is_nested_tensor(nt):
if not isinstance(nt, torch.Tensor):
return False
try:
# try calling size without an argument
nt.size()
return False
except RuntimeError:
return True
return False
但是有没有更简单的东西不依赖于size
未来不会改变的方法之类的东西?
torch.Tensor
所有s都有一个名为 的属性is_nested
,但遗憾的是它没有记录。仅在常见问题解答中提到