AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • 主页
  • 系统&网络
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • 主页
  • 系统&网络
    • 最新
    • 热门
    • 标签
  • Ubuntu
    • 最新
    • 热门
    • 标签
  • Unix
    • 最新
    • 标签
  • DBA
    • 最新
    • 标签
  • Computer
    • 最新
    • 标签
  • Coding
    • 最新
    • 标签
主页 / coding / 问题 / 77598815
Accepted
P.Jo
P.Jo
Asked: 2023-12-04 18:46:42 +0800 CST2023-12-04 18:46:42 +0800 CST 2023-12-04 18:46:42 +0800 CST

tf-神经网络不工作 - pytorch 可以

  • 772

我创建了一个很小的数据集,其中存在精确的线性关系。代码如下:

import numpy as np

def gen_data(n, k):
    np.random.seed(5711)
    beta = np.random.uniform(0, 1, size=(k, 1))
    print("beta is:", beta)
    X = np.random.normal(size=(n, k))
    y = X.dot(beta).reshape(-1, 1)
    D = np.concatenate([X, y], axis=1)
    return D.astype(np.float32)

现在我已经安装了一个带有 SGD 优化器和 MSE 损失的 pyTorch 神经网络,它在 50 个时期内近似收敛到真实值,学习率为 1e-1

我尝试在张量流中设置完全相同的模型:

import keras.layers
from sklearn.model_selection import train_test_split
from keras.models import Sequential
import tensorflow as tf

n = 10
k = 2
X = gen_data(n, k)
D_train, D_test = train_test_split(X, test_size=0.2)
X_train, y_train = D_train[:,:k], D_train[:,k:]
X_test, y_test = D_test[:,:k], D_test[:,k:]

model = Sequential([keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=1e-1), loss=tf.keras.losses.mean_squared_error)
model.fit(X_train, y_train, batch_size=64, epochs=50)

当我调用 model.get_weights 时,它显示与真实值的显着差异,并且损失仍然不接近于零。我不知道为什么这个模型的性能不如 pytorch 模型。即使您忽略 pytorch 模型,网络也不应该收敛到这个小玩具数据集中的真实值。我在设置模型时犯了什么错误?

编辑:这是我完整的 pytorch 代码进行比较:

import torch
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler, RandomSampler
from torch import nn
from sklearn.model_selection import train_test_split

n = 10
k = 2
device =  "cpu"

class Daten(Dataset):

    def __init__(self, df):
        self.df = df
        self.ycol = df.shape[1] - 1

    def __getitem__(self, index):
        return self.df[index, :self.ycol], self.df[index, self.ycol:]

    def __len__(self):
        return self.df.shape[0]

def split_into(D, batch_size=64, **kwargs):
    D_train, D_test = train_test_split(D, **kwargs)
    df_train, df_test = Daten(D_train), Daten(D_test)
    dl_train, dl_test = DataLoader(df_train, batch_size=batch_size), DataLoader(df_test, batch_size=batch_size)
    return dl_train, dl_test

D = gen_data(n, k)
dl_train, dl_test = split_into(D, test_size=0.2)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(k, 1)
        )

    def forward(self, x):
        ypred = self.linear(x)
        return ypred


model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        print(y.shape)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

epochs = 50
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(dl_train, model, loss_fn, optimizer)
print("Done!")

编辑:

我大幅增加了纪元。epochs=1000 后我们就接近真实值了。因此,我对差异的最佳猜测是 tf 应用了一些非最佳初始化?

python
  • 1 1 个回答
  • 68 Views

1 个回答

  • Voted
  1. Best Answer
    mhenning
    2023-12-04T21:37:47+08:002023-12-04T21:37:47+08:00

    您的lr参数SGD已被弃用:

    警告:absl:lr在 Keras 优化器中已弃用,请使用learning_rate或使用旧版优化器,例如 tf.keras.optimizers.legacy.SGD。

    如果我使用

    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-1), loss=tf.keras.losses.mean_squared_error)
    

    然后我得到了loss: 7.0588e-05(没有偏见loss: 2.0572e-08:)。
    通过我的简单火炬模型,我得到了loss: 5.3355e-05(没有偏见:)loss: 5.3071e-09。

    有趣的是,偏差在这里发挥了负面作用,我认为 X 和 y 之间的关系过于线性,以至于无法使用偏差,但模型无论如何都会尝试它。如果你添加这一行

    y += np.random.rand(*y.shape)*0.2
    

    对于数据创建,那么带有偏差的模型对于 torch 和 TF 的表现会更好,因为 X 和 y 之间的关系存在实际偏差,并且模型可以学习这一点。

    • 1

相关问题

  • 如何将 for 循环拆分为 3 个单独的数据框?

  • 如何检查 Pandas DataFrame 中的所有浮点列是否近似相等或接近

  • “load_dataset”如何工作,因为它没有检测示例文件?

  • 为什么 pandas.eval() 字符串比较返回 False

  • Python tkinter/ ttkboostrap dateentry 在只读状态下不起作用

Sidebar

Stats

  • 问题 205573
  • 回答 270741
  • 最佳答案 135370
  • 用户 68524
  • 热门
  • 回答
  • Marko Smith

    使用 <font color="#xxx"> 突出显示 html 中的代码

    • 2 个回答
  • Marko Smith

    为什么在传递 {} 时重载解析更喜欢 std::nullptr_t 而不是类?

    • 1 个回答
  • Marko Smith

    您可以使用花括号初始化列表作为(默认)模板参数吗?

    • 2 个回答
  • Marko Smith

    为什么列表推导式在内部创建一个函数?

    • 1 个回答
  • Marko Smith

    我正在尝试仅使用海龟随机和数学模块来制作吃豆人游戏

    • 1 个回答
  • Marko Smith

    java.lang.NoSuchMethodError: 'void org.openqa.selenium.remote.http.ClientConfig.<init>(java.net.URI, java.time.Duration, java.time.Duratio

    • 3 个回答
  • Marko Smith

    为什么 'char -> int' 是提升,而 'char -> Short' 是转换(但不是提升)?

    • 4 个回答
  • Marko Smith

    为什么库中不调用全局变量的构造函数?

    • 1 个回答
  • Marko Smith

    std::common_reference_with 在元组上的行为不一致。哪个是对的?

    • 1 个回答
  • Marko Smith

    C++17 中 std::byte 只能按位运算?

    • 1 个回答
  • Martin Hope
    fbrereto 为什么在传递 {} 时重载解析更喜欢 std::nullptr_t 而不是类? 2023-12-21 00:31:04 +0800 CST
  • Martin Hope
    比尔盖子 您可以使用花括号初始化列表作为(默认)模板参数吗? 2023-12-17 10:02:06 +0800 CST
  • Martin Hope
    Amir reza Riahi 为什么列表推导式在内部创建一个函数? 2023-11-16 20:53:19 +0800 CST
  • Martin Hope
    Michael A fmt 格式 %H:%M:%S 不带小数 2023-11-11 01:13:05 +0800 CST
  • Martin Hope
    God I Hate Python C++20 的 std::views::filter 未正确过滤视图 2023-08-27 18:40:35 +0800 CST
  • Martin Hope
    LiDa Cute 为什么 'char -> int' 是提升,而 'char -> Short' 是转换(但不是提升)? 2023-08-24 20:46:59 +0800 CST
  • Martin Hope
    jabaa 为什么库中不调用全局变量的构造函数? 2023-08-18 07:15:20 +0800 CST
  • Martin Hope
    Panagiotis Syskakis std::common_reference_with 在元组上的行为不一致。哪个是对的? 2023-08-17 21:24:06 +0800 CST
  • Martin Hope
    Alex Guteniev 为什么编译器在这里错过矢量化? 2023-08-17 18:58:07 +0800 CST
  • Martin Hope
    wimalopaan C++17 中 std::byte 只能按位运算? 2023-08-17 17:13:58 +0800 CST

热门标签

python javascript c++ c# java typescript sql reactjs html

Explore

  • 主页
  • 问题
    • 最新
    • 热门
  • 标签
  • 帮助

Footer

AskOverflow.Dev

关于我们

  • 关于我们
  • 联系我们

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve