AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • 主页
  • 系统&网络
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • 主页
  • 系统&网络
    • 最新
    • 热门
    • 标签
  • Ubuntu
    • 最新
    • 热门
    • 标签
  • Unix
    • 最新
    • 标签
  • DBA
    • 最新
    • 标签
  • Computer
    • 最新
    • 标签
  • Coding
    • 最新
    • 标签
主页 / coding / 问题 / 77439217
Accepted
user1168149
user1168149
Asked: 2023-11-07 22:51:43 +0800 CST2023-11-07 22:51:43 +0800 CST 2023-11-07 22:51:43 +0800 CST

JAX @jit 用于嵌套类方法

  • 772

我正在尝试与嵌套函数一起使用@jit,但遇到问题。我有一个类用 methodOne接收另一个类。我想将此方法称为 jitted from . 我认为我遵循了 JAX 的常见问题解答“如何将 jit 与方法一起使用?” 部分。 https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods 但是,我遇到了一个错误,指出 . 有人能告诉我如何解决这个问题吗?PlantfuncfuncOneTypeError: One.__init__() got multiple values for argument 'plant'

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util

class One:
    def __init__(self, plant,x):
        self.plant = plant
        self.x = x
    
    @jit
    def call_plant_func(self,y):
        out = self.plant.func(y) + self.x
        return out
    
    def _tree_flatten(self):
        children = (self.x,)  # arrays / dynamic values
        aux_data = {'plant':self.plant}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        import pdb; pdb.set_trace();
        return cls(*children, **aux_data)
        
tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)    
    
class Plant:
    def __init__(self, z,kk):
        self.z =z
    
    @jit
    def func(self,y):
        y = y + self.z
        return y
    
    def _tree_flatten(self):
        children = (self.z,)  # arrays / dynamic values
        aux_data = None # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        return cls(*children)
   
tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))

最后一行给出了上述错误。

nested
  • 1 1 个回答
  • 19 Views

1 个回答

  • Voted
  1. Best Answer
    jakevdp
    2023-11-07T23:24:00+08:002023-11-07T23:24:00+08:00

    tree_flatten您在两个类中的和代码中都存在问题tree_unflatten。

    • One._tree_flatten视为plant静态数据,但事实并非如此:它是一个具有非静态元素的 pytree。
    • One._tree_unflatten以错误的顺序实例化One参数,导致您看到的错误
    • Plant.__init__对论证没有任何作用kk。
    • Plant._tree_unflatten缺少aux_data参数,并且无法将kk参数传递给Plant.__init__

    解决这些问题后,您的代码执行时不会出现错误:

    class One:
        def __init__(self, plant,x):
            self.plant = plant
            self.x = x
        
        @jit
        def call_plant_func(self,y):
            out = self.plant.func(y) + self.x
            return out
        
        def _tree_flatten(self):
            children = (self.plant, self.x)
            aux_data = None
            return (children, aux_data)
    
        @classmethod
        def _tree_unflatten(cls, aux_data, children):
            return cls(*children)
            
    tree_util.register_pytree_node(One,
                                   One._tree_flatten,
                                   One._tree_unflatten)    
        
    class Plant:
        def __init__(self, z, kk):
            self.kk = kk
            self.z =z
        
        @jit
        def func(self, y):
            y = y + self.z
            return y
        
        def _tree_flatten(self):
            children = (self.z, self.kk)
            aux_data = None
            return (children, aux_data)
    
        @classmethod
        def _tree_unflatten(cls, aux_data, children):
            return cls(*children)
       
    tree_util.register_pytree_node(Plant,
                                   Plant._tree_flatten,
                                   Plant._tree_unflatten)
    
    plant = Plant(5,2)
    one = One(plant,2)
    print(one.call_plant_func(10))
    
    • 0

相关问题

Sidebar

Stats

  • 问题 205573
  • 回答 270741
  • 最佳答案 135370
  • 用户 68524
  • 热门
  • 回答
  • Marko Smith

    使用 <font color="#xxx"> 突出显示 html 中的代码

    • 2 个回答
  • Marko Smith

    为什么在传递 {} 时重载解析更喜欢 std::nullptr_t 而不是类?

    • 1 个回答
  • Marko Smith

    您可以使用花括号初始化列表作为(默认)模板参数吗?

    • 2 个回答
  • Marko Smith

    为什么列表推导式在内部创建一个函数?

    • 1 个回答
  • Marko Smith

    我正在尝试仅使用海龟随机和数学模块来制作吃豆人游戏

    • 1 个回答
  • Marko Smith

    java.lang.NoSuchMethodError: 'void org.openqa.selenium.remote.http.ClientConfig.<init>(java.net.URI, java.time.Duration, java.time.Duratio

    • 3 个回答
  • Marko Smith

    为什么 'char -> int' 是提升,而 'char -> Short' 是转换(但不是提升)?

    • 4 个回答
  • Marko Smith

    为什么库中不调用全局变量的构造函数?

    • 1 个回答
  • Marko Smith

    std::common_reference_with 在元组上的行为不一致。哪个是对的?

    • 1 个回答
  • Marko Smith

    C++17 中 std::byte 只能按位运算?

    • 1 个回答
  • Martin Hope
    fbrereto 为什么在传递 {} 时重载解析更喜欢 std::nullptr_t 而不是类? 2023-12-21 00:31:04 +0800 CST
  • Martin Hope
    比尔盖子 您可以使用花括号初始化列表作为(默认)模板参数吗? 2023-12-17 10:02:06 +0800 CST
  • Martin Hope
    Amir reza Riahi 为什么列表推导式在内部创建一个函数? 2023-11-16 20:53:19 +0800 CST
  • Martin Hope
    Michael A fmt 格式 %H:%M:%S 不带小数 2023-11-11 01:13:05 +0800 CST
  • Martin Hope
    God I Hate Python C++20 的 std::views::filter 未正确过滤视图 2023-08-27 18:40:35 +0800 CST
  • Martin Hope
    LiDa Cute 为什么 'char -> int' 是提升,而 'char -> Short' 是转换(但不是提升)? 2023-08-24 20:46:59 +0800 CST
  • Martin Hope
    jabaa 为什么库中不调用全局变量的构造函数? 2023-08-18 07:15:20 +0800 CST
  • Martin Hope
    Panagiotis Syskakis std::common_reference_with 在元组上的行为不一致。哪个是对的? 2023-08-17 21:24:06 +0800 CST
  • Martin Hope
    Alex Guteniev 为什么编译器在这里错过矢量化? 2023-08-17 18:58:07 +0800 CST
  • Martin Hope
    wimalopaan C++17 中 std::byte 只能按位运算? 2023-08-17 17:13:58 +0800 CST

热门标签

python javascript c++ c# java typescript sql reactjs html

Explore

  • 主页
  • 问题
    • 最新
    • 热门
  • 标签
  • 帮助

Footer

AskOverflow.Dev

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve