我正在尝试与嵌套函数一起使用@jit
,但遇到问题。我有一个类用 methodOne
接收另一个类。我想将此方法称为 jitted from . 我认为我遵循了 JAX 的常见问题解答“如何将 jit 与方法一起使用?” 部分。
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
但是,我遇到了一个错误,指出
. 有人能告诉我如何解决这个问题吗?Plant
func
func
One
TypeError: One.__init__() got multiple values for argument 'plant'
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util
class One:
def __init__(self, plant,x):
self.plant = plant
self.x = x
@jit
def call_plant_func(self,y):
out = self.plant.func(y) + self.x
return out
def _tree_flatten(self):
children = (self.x,) # arrays / dynamic values
aux_data = {'plant':self.plant} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
import pdb; pdb.set_trace();
return cls(*children, **aux_data)
tree_util.register_pytree_node(One,
One._tree_flatten,
One._tree_unflatten)
class Plant:
def __init__(self, z,kk):
self.z =z
@jit
def func(self,y):
y = y + self.z
return y
def _tree_flatten(self):
children = (self.z,) # arrays / dynamic values
aux_data = None # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, children):
return cls(*children)
tree_util.register_pytree_node(Plant,
Plant._tree_flatten,
Plant._tree_unflatten)
plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))
最后一行给出了上述错误。