Tenho uma situação em que preciso adicionar um tensor PyTorch a partes de outro tensor. Um exemplo é este:
import torch
x = torch.randn([10, 7, 128, 128]) # [batch, channel, height, width]
# In the actual program, batch_idx and channel_idx are generated dynamically
batch_idx = torch.tensor([1,3], dtype=torch.int64)
channel_idx = torch.tensor([2,3,5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128]) # [len(batch_idx), len(channel_idx), height, width]
x[batch_idx, channel_idx, :, :] += y
A execução deste código gera o seguinte erro:
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]
Como posso executar a operação desejada sem executar um loop em cada índice de cada dimensão?
O PyTorch espera que batch_idx e channel_idx possam ser transmitidos juntos, mas no seu caso batch_idx tem a forma [2] e channel_idx tem a forma [3], que não podem ser transmitidos diretamente.
Você pode tentar usar torch.meshgrid junto com indexação avançada:-
import torch
x = torch.randn([10, 7, 128, 128]) # [batch, channel, height, width]
batch_idx = torch.tensor([1, 3], dtype=torch.int64)
channel_idx = torch.tensor([2, 3, 5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128]) # [len(batch_idx), len(channel_idx), height, width]
# Create a meshgrid of batch and channel indices.
b_idx, c_idx = torch.meshgrid(batch_idx, channel_idx, indexing='ij') # shapes: [2, 3]
# Use the meshgrid to index x and then add y:-
x[b_idx, c_idx, :, :] += y