AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / user-10353865

P.Jo's questions

Martin Hope
P.Jo
Asked: 2023-12-04 18:46:42 +0800 CST

rede neural tf não funciona - pytorch funciona

  • 5

Eu criei um pequeno conjunto de dados onde se mantém uma relação linear exata. O código é o seguinte:

import numpy as np

def gen_data(n, k):
    np.random.seed(5711)
    beta = np.random.uniform(0, 1, size=(k, 1))
    print("beta is:", beta)
    X = np.random.normal(size=(n, k))
    y = X.dot(beta).reshape(-1, 1)
    D = np.concatenate([X, y], axis=1)
    return D.astype(np.float32)

Agora eu ajustei uma rede neural pyTorch com otimizador SGD e perda MSE e ela convergiu aproximadamente para os valores verdadeiros dentro de 50 épocas e uma taxa de aprendizado de 1e-1

Tentei configurar exatamente o mesmo modelo no tensorflow:

import keras.layers
from sklearn.model_selection import train_test_split
from keras.models import Sequential
import tensorflow as tf

n = 10
k = 2
X = gen_data(n, k)
D_train, D_test = train_test_split(X, test_size=0.2)
X_train, y_train = D_train[:,:k], D_train[:,k:]
X_test, y_test = D_test[:,:k], D_test[:,k:]

model = Sequential([keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=1e-1), loss=tf.keras.losses.mean_squared_error)
model.fit(X_train, y_train, batch_size=64, epochs=50)

Quando chamo model.get_weights ele mostra diferenças substanciais em relação aos valores verdadeiros e a perda ainda não chega nem perto de zero. Não sei por que este modelo não funciona tão bem quanto o modelo pytorch. Mesmo se você desconsiderar o modelo pytorch, a rede não deveria convergir para os valores verdadeiros neste pequeno conjunto de dados de brinquedo. Qual é o meu erro ao configurar o modelo?

EDIT: E aqui está meu código pytorch completo para comparação:

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler, RandomSampler
from torch import nn
from sklearn.model_selection import train_test_split

n = 10
k = 2
device =  "cpu"

class Daten(Dataset):

    def __init__(self, df):
        self.df = df
        self.ycol = df.shape[1] - 1

    def __getitem__(self, index):
        return self.df[index, :self.ycol], self.df[index, self.ycol:]

    def __len__(self):
        return self.df.shape[0]

def split_into(D, batch_size=64, **kwargs):
    D_train, D_test = train_test_split(D, **kwargs)
    df_train, df_test = Daten(D_train), Daten(D_test)
    dl_train, dl_test = DataLoader(df_train, batch_size=batch_size), DataLoader(df_test, batch_size=batch_size)
    return dl_train, dl_test

D = gen_data(n, k)
dl_train, dl_test = split_into(D, test_size=0.2)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(k, 1)
        )

    def forward(self, x):
        ypred = self.linear(x)
        return ypred


model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        print(y.shape)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

epochs = 50
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(dl_train, model, loss_fn, optimizer)
print("Done!")

EDITAR:

Aumentei as épocas dramaticamente. Depois de épocas=1000 chegamos perto dos valores verdadeiros. Portanto, meu melhor palpite para a discrepância é que tf aplica alguma inicialização não ideal?

python
  • 1 respostas
  • 68 Views
Martin Hope
P.Jo
Asked: 2023-10-10 21:31:48 +0800 CST

Digite regra de promoção para i4 e S8 no documento numpy

  • 7

Encontrei este exemplo no numpydoc:

np.promote_types('i4', 'S8')
dtype('S11')

Basicamente, não consigo entender o seguinte:

Um i4 ocupa 4 bytes e um S8 ocupa 8 bytes. Então, por que – em termos de memória – preciso de um S11 para acomodar os dois tipos? Eu teria esperado o S8 como resultado.

numpy
  • 2 respostas
  • 27 Views
Martin Hope
P.Jo
Asked: 2023-10-08 21:00:48 +0800 CST

chamando loc com uma matriz booleana contendo NA

  • 8

O documento pandas on loc afirma que pode ser usado com matrizes booleanas, mais especificamente afirma o seguinte:

"As entradas permitidas são: ... Uma matriz booleana (quaisquer valores NA serão tratados como falsos)."

Minha pergunta: como você pode criar uma matriz booleana contendo valores NA? Quero dizer: uma matriz numpy bool não pode conter Nans e se interpretarmos isso de forma mais liberal como afirmando "uma lista contendo valores booleanos e na", então loc lança exceções, por exemplo:

d_test = pd.DataFrame({"id": [1,2,3,5], "q1": [1,4,4,2], "q2": [4,np.nan,9,0]}, index=["a","b","c","d"])
t1 = [True,False,False,np.nan]
d_test.loc[t1] # KeyError
#same with None:
t1 = [True,False,False,None]

Então, minha pergunta: como esta frase deve ser interpretada?

pandas
  • 2 respostas
  • 35 Views
Martin Hope
P.Jo
Asked: 2023-08-30 19:52:11 +0800 CST

A função que retorna uma referência a um int pode ser atribuída a um int

  • 5

Sempre pensei que uma assinatura como

int& val() {...}

indicaria ao chamador que uma referência a um valor int já existente é retornada. No entanto, se eu usar essa função e atribuí-la a uma intvariável (não a um int&), ela será compilada. O resultado é, no entanto, uma cópia do valor real de retorno da função - conforme mostrado pela impressão dos endereços, como pode ser visto na sessão a seguir:

#include <iostream>

using namespace std;

int v = 8;
int& vref()
{
    return v;
}

int main()
{
 /*declared type of x does not match the return type! However, no warnings/messages about any implicit conversion are given here*/
 int x = vref();  
 int& y = vref(); // correct type
 // Now I declare pointers to x,y and v and display their addresses
 int* ptr_x = &x;
 int* ptr_y = &y;
 int* ptr_v = &v;
 cout << "Addr of x: " << ptr_x << endl;
 cout << "Addr of y: " << ptr_y << endl;
 cout << "Addr of v: " << ptr_v << endl;
}

Resultando ye vresidindo no mesmo local de memória (o que eu esperava) - mas não x. Então, eu me pergunto, o que está acontecendo nos bastidores ao atribuirint x = vref();

c++
  • 2 respostas
  • 64 Views

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    Reformatar números, inserindo separadores em posições fixas

    • 6 respostas
  • Marko Smith

    Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não?

    • 2 respostas
  • Marko Smith

    Problema com extensão desinstalada automaticamente do VScode (tema Material)

    • 2 respostas
  • Marko Smith

    Vue 3: Erro na criação "Identificador esperado, mas encontrado 'import'" [duplicado]

    • 1 respostas
  • Marko Smith

    Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores?

    • 1 respostas
  • Marko Smith

    Como faço para corrigir um erro MODULE_NOT_FOUND para um módulo que não importei manualmente?

    • 6 respostas
  • Marko Smith

    `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso?

    • 3 respostas
  • Marko Smith

    Um programa vazio que não faz nada em C++ precisa de um heap de 204 KB, mas não em C

    • 1 respostas
  • Marko Smith

    PowerBI atualmente quebrado com BigQuery: problema de driver Simba com atualização do Windows

    • 2 respostas
  • Marko Smith

    AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos

    • 1 respostas
  • Martin Hope
    Fantastic Mr Fox Somente o tipo copiável não é aceito na implementação std::vector do MSVC 2025-04-23 06:40:49 +0800 CST
  • Martin Hope
    Howard Hinnant Encontre o próximo dia da semana usando o cronógrafo 2025-04-21 08:30:25 +0800 CST
  • Martin Hope
    Fedor O inicializador de membro do construtor pode incluir a inicialização de outro membro? 2025-04-15 01:01:44 +0800 CST
  • Martin Hope
    Petr Filipský Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não? 2025-03-23 21:39:40 +0800 CST
  • Martin Hope
    Catskul O C++20 mudou para permitir a conversão de `type(&)[N]` de matriz de limites conhecidos para `type(&)[]` de matriz de limites desconhecidos? 2025-03-04 06:57:53 +0800 CST
  • Martin Hope
    Stefan Pochmann Como/por que {2,3,10} e {x,3,10} com x=2 são ordenados de forma diferente? 2025-01-13 23:24:07 +0800 CST
  • Martin Hope
    Chad Feller O ponto e vírgula agora é opcional em condicionais bash com [[ .. ]] na versão 5.2? 2024-10-21 05:50:33 +0800 CST
  • Martin Hope
    Wrench Por que um traço duplo (--) faz com que esta cláusula MariaDB seja avaliada como verdadeira? 2024-05-05 13:37:20 +0800 CST
  • Martin Hope
    Waket Zheng Por que `dict(id=1, **{'id': 2})` às vezes gera `KeyError: 'id'` em vez de um TypeError? 2024-05-04 14:19:19 +0800 CST
  • Martin Hope
    user924 AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos 2024-03-20 03:12:31 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve