AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / coding / Perguntas / 79551198
Accepted
jworrell
jworrell
Asked: 2025-04-03 01:23:05 +0800 CST2025-04-03 01:23:05 +0800 CST 2025-04-03 01:23:05 +0800 CST

Flax nnx / jax: tree.map para camadas de tamanho incongruente

  • 772

Estou tentando descobrir como usar nnx.split_rngs. Alguém pode dar uma versão do código abaixo que usa nnx.split_rngs com jax.tree.map para produzir um número arbitrário de camadas Lineares com diferentes out_features?

import jax
from flax import nnx
from functools import partial

if __name__ == '__main__':

    session_sizes = {
        'a':2,
        'b':3,
        'c':4,
        'd':5,
        'e':6,
    }
    dz = 2

    rngs = nnx.Rngs(0)
    
    my_linear = partial(
        nnx.Linear,
        use_bias = False,
        in_features = dz,
        rngs=rngs )
    
    def my_linear_wrapper(a):
        return my_linear( out_features=a )

    q_s = jax.tree.map(my_linear_wrapper, session_sizes)

    for k in session_sizes.keys():
        print(q_s[k].kernel)

Então, neste caso, precisaríamos de uma árvore de camadas que levaria nossos 2 in_features em espaços de 2, ..., 6 out_features.

A função my_linear_wrapper é uma espécie de solução alternativa para a solução original que tínhamos em mente, que é mapear da mesma forma que estamos fazendo, mas em vez disso usar (algo como) o decorador de função @nnx.split_rngs.

Existe uma maneira de usar nnx.split_rngs em my_linear para mapear o argumento rng para nnx.Linear?

python
  • 1 1 respostas
  • 34 Views

1 respostas

  • Voted
  1. Best Answer
    Cristian Garcia
    2025-04-03T03:22:51+08:002025-04-03T03:22:51+08:00

    split_rngsé mais útil quando você vai passar o Rngspor uma transformação como vmap, aqui você quer produzir Módulos de tamanho variável, então a solução atual é o caminho a seguir. Por causa de como partialfunciona, você pode simplificar isso para:

    din = 2
    rngs = nnx.Rngs(0)
    
    my_linear = functools.partial(
      nnx.Linear, din, use_bias=False, rngs=rngs
    )
    
    q_s = jax.tree.map(my_linear, session_sizes)
    
    for k in session_sizes.keys():
      print(q_s[k].kernel)
    
    • 0

relate perguntas

  • Como divido o loop for em 3 quadros de dados individuais?

  • Como verificar se todas as colunas flutuantes em um Pandas DataFrame são aproximadamente iguais ou próximas

  • Como funciona o "load_dataset", já que não está detectando arquivos de exemplo?

  • Por que a comparação de string pandas.eval() retorna False

  • Python tkinter/ ttkboostrap dateentry não funciona quando no estado somente leitura

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    Reformatar números, inserindo separadores em posições fixas

    • 6 respostas
  • Marko Smith

    Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não?

    • 2 respostas
  • Marko Smith

    Problema com extensão desinstalada automaticamente do VScode (tema Material)

    • 2 respostas
  • Marko Smith

    Vue 3: Erro na criação "Identificador esperado, mas encontrado 'import'" [duplicado]

    • 1 respostas
  • Marko Smith

    Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores?

    • 1 respostas
  • Marko Smith

    Como faço para corrigir um erro MODULE_NOT_FOUND para um módulo que não importei manualmente?

    • 6 respostas
  • Marko Smith

    `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso?

    • 3 respostas
  • Marko Smith

    Um programa vazio que não faz nada em C++ precisa de um heap de 204 KB, mas não em C

    • 1 respostas
  • Marko Smith

    PowerBI atualmente quebrado com BigQuery: problema de driver Simba com atualização do Windows

    • 2 respostas
  • Marko Smith

    AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos

    • 1 respostas
  • Martin Hope
    Fantastic Mr Fox Somente o tipo copiável não é aceito na implementação std::vector do MSVC 2025-04-23 06:40:49 +0800 CST
  • Martin Hope
    Howard Hinnant Encontre o próximo dia da semana usando o cronógrafo 2025-04-21 08:30:25 +0800 CST
  • Martin Hope
    Fedor O inicializador de membro do construtor pode incluir a inicialização de outro membro? 2025-04-15 01:01:44 +0800 CST
  • Martin Hope
    Petr Filipský Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não? 2025-03-23 21:39:40 +0800 CST
  • Martin Hope
    Catskul O C++20 mudou para permitir a conversão de `type(&)[N]` de matriz de limites conhecidos para `type(&)[]` de matriz de limites desconhecidos? 2025-03-04 06:57:53 +0800 CST
  • Martin Hope
    Stefan Pochmann Como/por que {2,3,10} e {x,3,10} com x=2 são ordenados de forma diferente? 2025-01-13 23:24:07 +0800 CST
  • Martin Hope
    Chad Feller O ponto e vírgula agora é opcional em condicionais bash com [[ .. ]] na versão 5.2? 2024-10-21 05:50:33 +0800 CST
  • Martin Hope
    Wrench Por que um traço duplo (--) faz com que esta cláusula MariaDB seja avaliada como verdadeira? 2024-05-05 13:37:20 +0800 CST
  • Martin Hope
    Waket Zheng Por que `dict(id=1, **{'id': 2})` às vezes gera `KeyError: 'id'` em vez de um TypeError? 2024-05-04 14:19:19 +0800 CST
  • Martin Hope
    user924 AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos 2024-03-20 03:12:31 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve