我有一个张量
import torch
a = torch.randn(1, 3, requires_grad=True)
print('a: ', a)
>>> a: tensor([[0.0200, 1.00200, -4.2000]], requires_grad=True)
还有一个面具
mask = torch.zeros_like(a)
mask[0][0] = 1
我想屏蔽我的张量a
而不将梯度传播到我的掩模张量(在我的实际情况中它有一个梯度)。我尝试了以下操作
with torch.no_grad():
b = a * mask
print('b: ', b)
>>> b: tensor([[0.0200, 0.0000, -0.0000]])
但它完全从我的张量中删除了梯度。正确的做法是什么?
您可以调用
detach
掩模张量将其从梯度链中删除。