我正在尝试与嵌套函数一起使用@jit
,但遇到问题。我有一个类用 methodOne
接收另一个类。我想将此方法称为 jitted from . 我认为我遵循了 JAX 的常见问题解答“如何将 jit 与方法一起使用?” 部分。
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
但是,我遇到了一个错误,指出
. 有人能告诉我如何解决这个问题吗?Plant
func
func
One
TypeError: One.__init__() got multiple values for argument 'plant'
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util
class One:
def __init__(self, plant,x):
self.plant = plant
self.x = x
@jit
def call_plant_func(self,y):
out = self.plant.func(y) + self.x
return out
def _tree_flatten(self):
children = (self.x,) # arrays / dynamic values
aux_data = {'plant':self.plant} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
import pdb; pdb.set_trace();
return cls(*children, **aux_data)
tree_util.register_pytree_node(One,
One._tree_flatten,
One._tree_unflatten)
class Plant:
def __init__(self, z,kk):
self.z =z
@jit
def func(self,y):
y = y + self.z
return y
def _tree_flatten(self):
children = (self.z,) # arrays / dynamic values
aux_data = None # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, children):
return cls(*children)
tree_util.register_pytree_node(Plant,
Plant._tree_flatten,
Plant._tree_unflatten)
plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))
最后一行给出了上述错误。
tree_flatten
您在两个类中的和代码中都存在问题tree_unflatten
。One._tree_flatten
视为plant
静态数据,但事实并非如此:它是一个具有非静态元素的 pytree。One._tree_unflatten
以错误的顺序实例化One
参数,导致您看到的错误Plant.__init__
对论证没有任何作用kk
。Plant._tree_unflatten
缺少aux_data
参数,并且无法将kk
参数传递给Plant.__init__
解决这些问题后,您的代码执行时不会出现错误: