I have a tensor p of shape (B, 3, N) in PyTorch:
import torch

# 2 batches, 3 channels (x, y, z), 5 points
p = torch.rand(2, 3, 5, requires_grad=True)
"""
p: tensor([[[0.8365, 0.0505, 0.4208, 0.7465, 0.6843],
            [0.9922, 0.2684, 0.6898, 0.3983, 0.4227],
            [0.3188, 0.2471, 0.9552, 0.5181, 0.6877]],
           [[0.1079, 0.7694, 0.2194, 0.7801, 0.8043],
            [0.8554, 0.3505, 0.4622, 0.0339, 0.7909],
            [0.5806, 0.7593, 0.0193, 0.5191, 0.1589]]], requires_grad=True)
"""
And then another tensor, z_shift, of shape [B, 1]:
z_shift = torch.tensor([[1.0], [10.0]], requires_grad=True)
"""
z_shift: tensor([[1.],
                 [10.]], requires_grad=True)
"""
I want to apply the appropriate z shift to all the points in each batch, leaving x and y unchanged:
"""
p: tensor([[[0.8365, 0.0505, 0.4208, 0.7465, 0.6843],
            [0.9922, 0.2684, 0.6898, 0.3983, 0.4227],
            [1.3188, 1.2471, 1.9552, 1.5181, 1.6877]],
           [[0.1079, 0.7694, 0.2194, 0.7801, 0.8043],
            [0.8554, 0.3505, 0.4622, 0.0339, 0.7909],
            [10.5806, 10.7593, 10.0193, 10.5191, 10.1589]]])
"""
I managed to do it like this:
p[:, 2, :] += z_shift
for the case where requires_grad=False, but it fails inside the forward of my nn.Module (which I assume is equivalent to requires_grad=True) with:
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
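I assume the fix is to build the shift out of place instead of writing into a view of p. Something like the sketch below is what I have in mind (z_offset and p_shifted are just names I made up), but I'm not sure it's the idiomatic way, or whether gradients will flow back to z_shift as I expect:

# Sketch of an out-of-place alternative: build a (B, 3, N) offset that is
# zero on the x and y channels and equal to z_shift on the z channel,
# then add it without modifying p in place.
z_offset = torch.zeros_like(p)   # (B, 3, N); zeros_like does not require grad by default
z_offset[:, 2, :] = z_shift      # broadcasts (B, 1) over the N points
p_shifted = p + z_offset         # out-of-place add, so the leaf p is never modified

Is this the right approach, or is there a cleaner way to do it?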