Recentemente, PyTorch introduziu o tensor aninhado . No entanto, se eu criar um tensor aninhado, por exemplo,
import torch
a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
e então observe seu tipo de classe, ele mostra:
type(nt)
torch.Tensor
ou seja, o tipo de classe é apenas um PyTorch normal Tensor
. Então, type(nt) == torch.Tensor
e isinstance(nt, torch.Tensor)
ambos retornarão True
.
Então, minha pergunta é: existe uma maneira de diferenciar entre um tensor regular e um tensor aninhado?
Uma maneira que posso pensar é que o size
método para tensores aninhados (atualmente) funciona de maneira diferente daquele para tensores regulares, pois requer um argumento, caso contrário, gera a RuntimeError
. Então, uma solução pode ser:
def is_nested_tensor(nt):
if not isinstance(nt, torch.Tensor):
return False
try:
# try calling size without an argument
nt.size()
return False
except RuntimeError:
return True
return False
mas existe algo mais simples que não depende de algo como o size
método não mudar no futuro?
Existe um atributo para todos
torch.Tensor
os chamadosis_nested
, mas infelizmente não está documentado. É mencionado apenas no FAQ