Estou tentando descobrir como usar nnx.split_rngs. Alguém pode dar uma versão do código abaixo que usa nnx.split_rngs com jax.tree.map para produzir um número arbitrário de camadas Lineares com diferentes out_features?
import jax
from flax import nnx
from functools import partial
if __name__ == '__main__':
session_sizes = {
'a':2,
'b':3,
'c':4,
'd':5,
'e':6,
}
dz = 2
rngs = nnx.Rngs(0)
my_linear = partial(
nnx.Linear,
use_bias = False,
in_features = dz,
rngs=rngs )
def my_linear_wrapper(a):
return my_linear( out_features=a )
q_s = jax.tree.map(my_linear_wrapper, session_sizes)
for k in session_sizes.keys():
print(q_s[k].kernel)
Então, neste caso, precisaríamos de uma árvore de camadas que levaria nossos 2 in_features em espaços de 2, ..., 6 out_features.
A função my_linear_wrapper é uma espécie de solução alternativa para a solução original que tínhamos em mente, que é mapear da mesma forma que estamos fazendo, mas em vez disso usar (algo como) o decorador de função @nnx.split_rngs.
Existe uma maneira de usar nnx.split_rngs em my_linear para mapear o argumento rng para nnx.Linear?
split_rngs
é mais útil quando você vai passar oRngs
por uma transformação comovmap
, aqui você quer produzir Módulos de tamanho variável, então a solução atual é o caminho a seguir. Por causa de comopartial
funciona, você pode simplificar isso para: