Eu tenho uma matriz numpy tridimensional. Qual é a maneira mais rápida de obter um array 3D que tenha o maior item de cada eixo final do array sem escrever um loop. (Mais tarde usarei o CuPy com a mesma sintaxe, e os loops tirariam o paralelismo da GPU e a velocidade qual é o fator mais importante aqui.)
Obter os índices dos itens maiores é fácil:
>>> arr = np.array(
[[[ 6, -2, -6, -5],
[ 1, 12, 3, 9],
[21, 7, 9, 8]],
[[15, 12, 20, 12],
[17, 15, 17, 23],
[22, 18, 27, 32]]])
>>> indexes = arr.argmax(axis=2, keepdims=True)
>>> indexes
array([[[0],
[1],
[0]],
[[2],
[3],
[3]]])
mas como usar esses índices para obter os valores escolhidos em arr? Todas as maneiras que tentei produziram erros (como arr[indexes]) ou resultados errados. O que eu gostaria de obter neste exemplo é
array([[[6],
[12],
[21]],
[[20],
[23],
[32]]])
Eu acho que você pode usar np.amax para isso
Saída
Conforme apontado por hpaulj, você pode usar
take_along_axis