Estou fazendo algumas simulações de física usando jax, e isso envolve uma função chamada Hamiltoniano, definida da seguinte forma:
# Constructing the Hamiltonian
@partial(jit, static_argnames=['n', 'omega'])
def hamiltonian(n: int, omega: float):
"""Construct the Hamiltonian for the system."""
H = omega * create(n) @ annhilate(n)
return H
e então uma função maior def solve_diff(n, omega, kappa, alpha0):
que é definida da seguinte forma:
@partial(jit, static_argnames=['n', 'omega'])
def solve_diff(n, omega, kappa, alpha0):
# Some functionality that uses kappa and alpha0
H = hamiltonian(n, omega)
# returns an expectation value
Quando tento calcular o gradiente desta função usando jax.grad
n = 16
omega = 1.0
kappa = 0.1
alpha0 = 1.0
# Compute gradients with respect to omega, kappa, and alpha0
grad_population = grad(solve_diff, argnums=(1, 2, 3))
grads = grad_population(n, omega, kappa, alpha0)
print(f"Gradient w.r.t. omega: {grads[0]}")
print(f"Gradient w.r.t. kappa: {grads[1]}")
print(f"Gradient w.r.t. alpha0: {grads[2]}")
ele gera o seguinte erro:
ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.ad.JVPTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<JVPTrace> with
primal = 1.0
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace> with
pval = (ShapedArray(float32[], weak_type=True), None)
recipe = LambdaBinding(). The error was:
TypeError: unhashable type: 'JVPTracer'
Porém, rodando solve_diff(16,1.0,0.1,1.0)
sozinho funciona como esperado.
Agora, se eu remover omega
da lista de variáveis estáticas tanto a hamiltonian
função quanto o solve_diff
, o grad será gerado conforme o esperado.
Isso está me confundindo, porque não sei mais o que se qualifica como variáveis estáticas ou dinâmicas, a partir da definição de que variáveis estáticas não mudam entre chamadas de função, ambas n
e omega
são constantes e, de fato, não devem mudar entre chamadas de função.
A questão fundamental é que você não pode diferenciar em relação a uma variável estática e, se tentar fazer isso, obterá o erro que observou.
Em JAX, o termo "estático" não tem a ver com se a variável é alterada entre chamadas de função. Em vez disso, uma variável estática é uma variável que não participa do rastreamento, que é o mecanismo usado para calcular transformações como
vmap
,grad
,jit
, etc. Quando você diferencia em relação a uma variável, ela não é mais estática porque está participando da transformação autodiff, e tentar tratá-la como estática mais tarde na computação levará a um erro.Para uma discussão sobre transformações, rastreamento e conceitos relacionados, eu começaria com Conceitos-chave do JAX: transformações .