Preciso classificar um lote de linhas da matriz 2D pelo valor-chave da primeira coluna:
matrizes de lote originais (tensor 3d):
torch.tensor([[[2, 0],
[0, 1],
[1, 2]],
[[1, 2],
[0, 0],
[2, 1]]])
tensor desejado:
torch.tensor([[[0, 1],
[1, 2],
[2, 0]],
[[0, 0],
[1, 2],
[2, 1]]])
Já sei como lidar com um dos lotes , e outra resposta resolve o problema pelo loop for, que não é paralelo. Então, como lidar com todo o lote paralelamente?
Isso pode ser um pouco confuso, mas faz sentido:
Na primeira linha você extrai o tensor de classificação pensado
torch.argsort
e o aplicamy_tensor
, resultando em um(2, 2, 3, 2)
tensor de forma. Como você deseja que cada elemento seja classificado apenas de acordo com sua primeira coluna, você está interessado apenas na diagonal das duas primeiras dimensões e pode extraí-la fatiando (segunda linha do código).