Ainda estou pegando o jeito das melhores práticas em jax
. Minha pergunta geral é a seguinte:
Quais são as melhores práticas para a implementação de rotinas de criação de matrizes personalizadas em jax
?
Por exemplo, eu quero implementar uma função que cria uma matriz com zeros em todos os lugares, exceto com uns em uma coluna dada. Eu fui para isso (Jupyter notebook):
import numpy as np
import jax.numpy as jnp
def ones_at_col(shape_mat, idx):
idxs = jnp.arange(shape_mat[1])[None,:]
mat = jnp.where(idx==idxs, 1, 0)
mat = jnp.repeat(mat, shape_mat[0], axis=0)
return mat
shape_mat = (5,10)
print(ones_at_col(shape_mat, 5))
%timeit np.zeros(shape_mat)
%timeit jnp.zeros(shape_mat)
%timeit ones_at_col(shape_mat, 5)
A saída é
[[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]]
127 ns ± 0.717 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
31.3 µs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
123 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Minha função é um fator de 4 mais lenta que a jnp.zeros()
rotina, o que não é tão ruim. Isso me diz que o que estou fazendo não é loucura.
Mas então ambas jax
as rotinas são muito mais lentas do que as numpy
rotinas equivalentes. Essas funções não podem ser jitadas porque elas tomam a forma como um argumento e, portanto, não podem ser rastreadas. Presumo que seja por isso que elas são inerentemente mais lentas? Acho que se qualquer uma delas aparecesse dentro do escopo de outra função jitada, elas poderiam ser rastreadas e aceleradas?
Existe algo melhor que eu possa fazer ou estou forçando os limites do que é possível jax
?