AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / coding / Perguntas / 77598815
Accepted
P.Jo
P.Jo
Asked: 2023-12-04 18:46:42 +0800 CST2023-12-04 18:46:42 +0800 CST 2023-12-04 18:46:42 +0800 CST

rede neural tf não funciona - pytorch funciona

  • 772

Eu criei um pequeno conjunto de dados onde se mantém uma relação linear exata. O código é o seguinte:

import numpy as np

def gen_data(n, k):
    np.random.seed(5711)
    beta = np.random.uniform(0, 1, size=(k, 1))
    print("beta is:", beta)
    X = np.random.normal(size=(n, k))
    y = X.dot(beta).reshape(-1, 1)
    D = np.concatenate([X, y], axis=1)
    return D.astype(np.float32)

Agora eu ajustei uma rede neural pyTorch com otimizador SGD e perda MSE e ela convergiu aproximadamente para os valores verdadeiros dentro de 50 épocas e uma taxa de aprendizado de 1e-1

Tentei configurar exatamente o mesmo modelo no tensorflow:

import keras.layers
from sklearn.model_selection import train_test_split
from keras.models import Sequential
import tensorflow as tf

n = 10
k = 2
X = gen_data(n, k)
D_train, D_test = train_test_split(X, test_size=0.2)
X_train, y_train = D_train[:,:k], D_train[:,k:]
X_test, y_test = D_test[:,:k], D_test[:,k:]

model = Sequential([keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=1e-1), loss=tf.keras.losses.mean_squared_error)
model.fit(X_train, y_train, batch_size=64, epochs=50)

Quando chamo model.get_weights ele mostra diferenças substanciais em relação aos valores verdadeiros e a perda ainda não chega nem perto de zero. Não sei por que este modelo não funciona tão bem quanto o modelo pytorch. Mesmo se você desconsiderar o modelo pytorch, a rede não deveria convergir para os valores verdadeiros neste pequeno conjunto de dados de brinquedo. Qual é o meu erro ao configurar o modelo?

EDIT: E aqui está meu código pytorch completo para comparação:

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler, RandomSampler
from torch import nn
from sklearn.model_selection import train_test_split

n = 10
k = 2
device =  "cpu"

class Daten(Dataset):

    def __init__(self, df):
        self.df = df
        self.ycol = df.shape[1] - 1

    def __getitem__(self, index):
        return self.df[index, :self.ycol], self.df[index, self.ycol:]

    def __len__(self):
        return self.df.shape[0]

def split_into(D, batch_size=64, **kwargs):
    D_train, D_test = train_test_split(D, **kwargs)
    df_train, df_test = Daten(D_train), Daten(D_test)
    dl_train, dl_test = DataLoader(df_train, batch_size=batch_size), DataLoader(df_test, batch_size=batch_size)
    return dl_train, dl_test

D = gen_data(n, k)
dl_train, dl_test = split_into(D, test_size=0.2)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(k, 1)
        )

    def forward(self, x):
        ypred = self.linear(x)
        return ypred


model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        print(y.shape)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

epochs = 50
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(dl_train, model, loss_fn, optimizer)
print("Done!")

EDITAR:

Aumentei as épocas dramaticamente. Depois de épocas=1000 chegamos perto dos valores verdadeiros. Portanto, meu melhor palpite para a discrepância é que tf aplica alguma inicialização não ideal?

python
  • 1 1 respostas
  • 68 Views

1 respostas

  • Voted
  1. Best Answer
    mhenning
    2023-12-04T21:37:47+08:002023-12-04T21:37:47+08:00

    Seu lrparâmetro para SGDestá obsoleto:

    AVISO:absl: lrestá obsoleto no otimizador Keras, use learning_rateou use o otimizador legado, por exemplo, tf.keras.optimizers.legacy.SGD.

    Se eu usar

    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-1), loss=tf.keras.losses.mean_squared_error)
    

    Então eu entendo loss: 7.0588e-05(sem preconceito loss: 2.0572e-08:).
    Com meu modelo de tocha simples, consegui loss: 5.3355e-05(sem preconceito:) loss: 5.3071e-09.

    É interessante que o viés desempenhe um papel negativo aqui, acho que a relação entre X e y é linear demais para que o viés seja usado, mas o modelo tenta mesmo assim. Se você adicionar a linha

    y += np.random.rand(*y.shape)*0.2
    

    para a criação de dados, então o modelo com viés terá melhor desempenho para tocha e TF, pois há viés real na relação entre X e y e o modelo pode aprender isso.

    • 1

relate perguntas

  • Como divido o loop for em 3 quadros de dados individuais?

  • Como verificar se todas as colunas flutuantes em um Pandas DataFrame são aproximadamente iguais ou próximas

  • Como funciona o "load_dataset", já que não está detectando arquivos de exemplo?

  • Por que a comparação de string pandas.eval() retorna False

  • Python tkinter/ ttkboostrap dateentry não funciona quando no estado somente leitura

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    destaque o código em HTML usando <font color="#xxx">

    • 2 respostas
  • Marko Smith

    Por que a resolução de sobrecarga prefere std::nullptr_t a uma classe ao passar {}?

    • 1 respostas
  • Marko Smith

    Você pode usar uma lista de inicialização com chaves como argumento de modelo (padrão)?

    • 2 respostas
  • Marko Smith

    Por que as compreensões de lista criam uma função internamente?

    • 1 respostas
  • Marko Smith

    Estou tentando fazer o jogo pacman usando apenas o módulo Turtle Random e Math

    • 1 respostas
  • Marko Smith

    java.lang.NoSuchMethodError: 'void org.openqa.selenium.remote.http.ClientConfig.<init>(java.net.URI, java.time.Duration, java.time.Duratio

    • 3 respostas
  • Marko Smith

    Por que 'char -> int' é promoção, mas 'char -> short' é conversão (mas não promoção)?

    • 4 respostas
  • Marko Smith

    Por que o construtor de uma variável global não é chamado em uma biblioteca?

    • 1 respostas
  • Marko Smith

    Comportamento inconsistente de std::common_reference_with em tuplas. Qual é correto?

    • 1 respostas
  • Marko Smith

    Somente operações bit a bit para std::byte em C++ 17?

    • 1 respostas
  • Martin Hope
    fbrereto Por que a resolução de sobrecarga prefere std::nullptr_t a uma classe ao passar {}? 2023-12-21 00:31:04 +0800 CST
  • Martin Hope
    比尔盖子 Você pode usar uma lista de inicialização com chaves como argumento de modelo (padrão)? 2023-12-17 10:02:06 +0800 CST
  • Martin Hope
    Amir reza Riahi Por que as compreensões de lista criam uma função internamente? 2023-11-16 20:53:19 +0800 CST
  • Martin Hope
    Michael A formato fmt %H:%M:%S sem decimais 2023-11-11 01:13:05 +0800 CST
  • Martin Hope
    God I Hate Python std::views::filter do C++20 não filtrando a visualização corretamente 2023-08-27 18:40:35 +0800 CST
  • Martin Hope
    LiDa Cute Por que 'char -> int' é promoção, mas 'char -> short' é conversão (mas não promoção)? 2023-08-24 20:46:59 +0800 CST
  • Martin Hope
    jabaa Por que o construtor de uma variável global não é chamado em uma biblioteca? 2023-08-18 07:15:20 +0800 CST
  • Martin Hope
    Panagiotis Syskakis Comportamento inconsistente de std::common_reference_with em tuplas. Qual é correto? 2023-08-17 21:24:06 +0800 CST
  • Martin Hope
    Alex Guteniev Por que os compiladores perdem a vetorização aqui? 2023-08-17 18:58:07 +0800 CST
  • Martin Hope
    wimalopaan Somente operações bit a bit para std::byte em C++ 17? 2023-08-17 17:13:58 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve