Ainda estou pegando o jeito das melhores práticas em jax
. Minha pergunta geral é a seguinte:
Quais são as melhores práticas para a implementação de rotinas de criação de matrizes personalizadas em jax
?
Por exemplo, eu quero implementar uma função que cria uma matriz com zeros em todos os lugares, exceto com uns em uma coluna dada. Eu fui para isso (Jupyter notebook):
import numpy as np
import jax.numpy as jnp
def ones_at_col(shape_mat, idx):
idxs = jnp.arange(shape_mat[1])[None,:]
mat = jnp.where(idx==idxs, 1, 0)
mat = jnp.repeat(mat, shape_mat[0], axis=0)
return mat
shape_mat = (5,10)
print(ones_at_col(shape_mat, 5))
%timeit np.zeros(shape_mat)
%timeit jnp.zeros(shape_mat)
%timeit ones_at_col(shape_mat, 5)
A saída é
[[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]]
127 ns ± 0.717 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
31.3 µs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
123 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Minha função é um fator de 4 mais lenta que a jnp.zeros()
rotina, o que não é tão ruim. Isso me diz que o que estou fazendo não é loucura.
Mas então ambas jax
as rotinas são muito mais lentas do que as numpy
rotinas equivalentes. Essas funções não podem ser jitadas porque elas tomam a forma como um argumento e, portanto, não podem ser rastreadas. Presumo que seja por isso que elas são inerentemente mais lentas? Acho que se qualquer uma delas aparecesse dentro do escopo de outra função jitada, elas poderiam ser rastreadas e aceleradas?
Existe algo melhor que eu possa fazer ou estou forçando os limites do que é possível jax
?
A melhor maneira de fazer isso é provavelmente algo assim:
Em relação a comparações de tempo com NumPy, a leitura relevante é JAX FAQ: o JAX é mais rápido que o NumPy? O resumo é que para este caso em particular (criação de um array simples) você não esperaria que o JAX correspondesse ao NumPy em termos de desempenho, devido à sobrecarga de despacho por operação do JAX.
Se você deseja um desempenho mais rápido em JAX, você deve sempre usar
jax.jit
para compilar sua função just-in-time. Por exemplo, esta versão da função deve ser bem ótima (embora, novamente, não tão rápida quanto NumPy pelos motivos discutidos no link FAQ):Você pode deixar o
idx
valor não estático se for chamar essa função várias vezes com valores de índice diferentes e, se estiver criando essas matrizes dentro de outra função, basta colocar o código em linha e compilar JIT essa função externa.Outra observação: seus microbenchmarks podem não estar medindo o que você acha que estão medindo: para dicas sobre isso, veja JAX FAQ: benchmarking JAX code . Em particular, tenha cuidado com o tempo de compilação e os efeitos de despacho assíncrono.