AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / coding / Perguntas / 79256001
Accepted
Ben
Ben
Asked: 2024-12-06 03:29:25 +0800 CST2024-12-06 03:29:25 +0800 CST 2024-12-06 03:29:25 +0800 CST

Rotinas de criação de arrays personalizadas de forma eficiente em JAX

  • 772

Ainda estou pegando o jeito das melhores práticas em jax. Minha pergunta geral é a seguinte:

Quais são as melhores práticas para a implementação de rotinas de criação de matrizes personalizadas em jax?

Por exemplo, eu quero implementar uma função que cria uma matriz com zeros em todos os lugares, exceto com uns em uma coluna dada. Eu fui para isso (Jupyter notebook):

import numpy as np
import jax.numpy as jnp

def ones_at_col(shape_mat, idx):
    idxs = jnp.arange(shape_mat[1])[None,:]
    mat = jnp.where(idx==idxs, 1, 0)
    mat = jnp.repeat(mat, shape_mat[0], axis=0)
    return mat

shape_mat = (5,10)

print(ones_at_col(shape_mat, 5))

%timeit np.zeros(shape_mat)

%timeit jnp.zeros(shape_mat)

%timeit ones_at_col(shape_mat, 5)

A saída é

[[0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]]
127 ns ± 0.717 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
31.3 µs ± 331 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
123 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Minha função é um fator de 4 mais lenta que a jnp.zeros()rotina, o que não é tão ruim. Isso me diz que o que estou fazendo não é loucura.

Mas então ambas jaxas rotinas são muito mais lentas do que as numpyrotinas equivalentes. Essas funções não podem ser jitadas porque elas tomam a forma como um argumento e, portanto, não podem ser rastreadas. Presumo que seja por isso que elas são inerentemente mais lentas? Acho que se qualquer uma delas aparecesse dentro do escopo de outra função jitada, elas poderiam ser rastreadas e aceleradas?

Existe algo melhor que eu possa fazer ou estou forçando os limites do que é possível jax?

numpy
  • 1 1 respostas
  • 20 Views

1 respostas

  • Voted
  1. Best Answer
    jakevdp
    2024-12-06T04:08:00+08:002024-12-06T04:08:00+08:00

    A melhor maneira de fazer isso é provavelmente algo assim:

    mat = jnp.zeros(shape_mat).at[:, 5].set(1)
    

    Em relação a comparações de tempo com NumPy, a leitura relevante é JAX FAQ: o JAX é mais rápido que o NumPy? O resumo é que para este caso em particular (criação de um array simples) você não esperaria que o JAX correspondesse ao NumPy em termos de desempenho, devido à sobrecarga de despacho por operação do JAX.

    Se você deseja um desempenho mais rápido em JAX, você deve sempre usar jax.jitpara compilar sua função just-in-time. Por exemplo, esta versão da função deve ser bem ótima (embora, novamente, não tão rápida quanto NumPy pelos motivos discutidos no link FAQ):

    @partial(jax.jit, static_argnames=['shape_mat', 'idx'])
    def ones_at_col(shape_mat, idx):
      return jnp.zeros(shape_mat).at[:, idx].set(1)
    

    Você pode deixar o idxvalor não estático se for chamar essa função várias vezes com valores de índice diferentes e, se estiver criando essas matrizes dentro de outra função, basta colocar o código em linha e compilar JIT essa função externa.

    Outra observação: seus microbenchmarks podem não estar medindo o que você acha que estão medindo: para dicas sobre isso, veja JAX FAQ: benchmarking JAX code . Em particular, tenha cuidado com o tempo de compilação e os efeitos de despacho assíncrono.

    • 1

relate perguntas

  • Como classificar o tensor em "lote" por valor de chave específico?

  • Aviso de descontinuação do notebook Jupyter ao encontrar a raiz do determinante de uma matriz

  • Como você concatena matrizes internas do tensor ao longo do eixo?

  • Digite regra de promoção para i4 e S8 no documento numpy

  • Transmitindo uma matriz numpy em uma matriz de tamanho maior usando uma matriz de índice

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    Vue 3: Erro na criação "Identificador esperado, mas encontrado 'import'" [duplicado]

    • 1 respostas
  • Marko Smith

    Por que esse código Java simples e pequeno roda 30x mais rápido em todas as JVMs Graal, mas não em nenhuma JVM Oracle?

    • 1 respostas
  • Marko Smith

    Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores?

    • 1 respostas
  • Marko Smith

    Como faço para corrigir um erro MODULE_NOT_FOUND para um módulo que não importei manualmente?

    • 6 respostas
  • Marko Smith

    `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso?

    • 3 respostas
  • Marko Smith

    Quando devo usar um std::inplace_vector em vez de um std::vector?

    • 3 respostas
  • Marko Smith

    Um programa vazio que não faz nada em C++ precisa de um heap de 204 KB, mas não em C

    • 1 respostas
  • Marko Smith

    PowerBI atualmente quebrado com BigQuery: problema de driver Simba com atualização do Windows

    • 2 respostas
  • Marko Smith

    AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos

    • 1 respostas
  • Marko Smith

    Estou tentando fazer o jogo pacman usando apenas o módulo Turtle Random e Math

    • 1 respostas
  • Martin Hope
    Aleksandr Dubinsky Por que a correspondência de padrões com o switch no InetAddress falha com 'não cobre todos os valores de entrada possíveis'? 2024-12-23 06:56:21 +0800 CST
  • Martin Hope
    Phillip Borge Por que esse código Java simples e pequeno roda 30x mais rápido em todas as JVMs Graal, mas não em nenhuma JVM Oracle? 2024-12-12 20:46:46 +0800 CST
  • Martin Hope
    Oodini Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores? 2024-12-12 06:27:11 +0800 CST
  • Martin Hope
    sleeptightAnsiC `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso? 2024-11-09 07:18:53 +0800 CST
  • Martin Hope
    The Mad Gamer Quando devo usar um std::inplace_vector em vez de um std::vector? 2024-10-29 23:01:00 +0800 CST
  • Martin Hope
    Chad Feller O ponto e vírgula agora é opcional em condicionais bash com [[ .. ]] na versão 5.2? 2024-10-21 05:50:33 +0800 CST
  • Martin Hope
    Wrench Por que um traço duplo (--) faz com que esta cláusula MariaDB seja avaliada como verdadeira? 2024-05-05 13:37:20 +0800 CST
  • Martin Hope
    Waket Zheng Por que `dict(id=1, **{'id': 2})` às vezes gera `KeyError: 'id'` em vez de um TypeError? 2024-05-04 14:19:19 +0800 CST
  • Martin Hope
    user924 AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos 2024-03-20 03:12:31 +0800 CST
  • Martin Hope
    MarkB Por que o GCC gera código que executa condicionalmente uma implementação SIMD? 2024-02-17 06:17:14 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve