我得到了以下 U-net 架构导致的问题:
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
self.encoder1 = self.double_conv(in_channels, 64)
self.encoder2 = self.down(64, 128)
self.encoder3 = self.down(128, 256)
self.encoder4 = self.down(256, 512)
self.bottleneck = self.double_conv(512, 1024)
self.decoder4 = self.up(1024, 512)
self.decoder3 = self.up(512, 256)
self.decoder2 = self.up(256, 128)
self.decoder1 = self.up(128, 64)
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) # SAME convolution/padding
def double_conv(self, in_channels, out_channels): # Convo Block
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
def down(self, in_channels, out_channels):
return nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
self.double_conv(in_channels, out_channels),
)
def up(self, in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
self.double_conv(in_channels, out_channels),
)
def forward(self, x):
# Encoder
enc1 = self.encoder1(x) # Output: [1, 64, 256, 256]
print("enc1.shape",enc1.shape)
enc2 = self.encoder2(enc1) # Output: [1, 128, 128, 128]
print("enc2.shape",enc2.shape)
enc3 = self.encoder3(enc2) # Output: [1, 256, 64, 64]
print("enc3.shape",enc3.shape)
enc4 = self.encoder4(enc3) # Output: [1, 512, 32, 32]
print("enc4.shape",enc4.shape)
bottleneck_output = self.bottleneck(enc4) # Output: [1, 1024, 32, 32]
print("bottleneck_output",bottleneck_output.shape)
# Decoder
dec4 = self.decoder4(bottleneck_output)#bottleneck_output) # Output: [1, 512, 64, 64]
print(dec4.shape)
dec4 = torch.cat((dec4, enc4), dim=1) # skip connect, Concatenate: [1, 1024, 64, 64]
dec4 = self.double_conv(1024, 512)(dec4) # Corrected input channels to 1024
dec3 = self.decoder3(dec4) # Output: [1, 256, 128, 128]
dec3 = torch.cat((dec3, enc3), dim=1) # Concatenate: [1, 512, 128, 128]
dec3 = self.double_conv(512, 256)(dec3) # Corrected input channels to 512
dec2 = self.decoder2(dec3) # Output: [1, 128, 256, 256]
dec2 = torch.cat((dec2, enc2), dim=1) # Concatenate: [1, 256, 256, 256]
dec2 = self.double_conv(256, 128)(dec2) # Corrected input channels to 256
dec1 = self.decoder1(dec2) # Output: [1, 64, 512, 512]
dec1 = torch.cat((dec1, enc1), dim=1) # Concatenate: [1, 128, 512, 512]
dec1 = self.double_conv(128, 64)(dec1) # Corrected input channels to 128
return self.final_conv(dec1) # Output: [1, 1, 512, 512]```
在通过 main 方法执行时
unet = UNet(in_channels=1, out_channels=1)
sample_input = torch.randn(1, 1, 256, 256)
output = unet(sample_input)
我得到:
enc1.shape torch.Size([1, 64, 256, 256])
enc2.shape torch.Size([1, 128, 128, 128])
enc3.shape torch.Size([1, 256, 64, 64])
enc4.shape torch.Size([1, 512, 32, 32])
bottleneck_output torch.Size([1, 1024, 32, 32])
并出现以下错误:
---> 55 dec4 = self.decoder4(bottleneck_output)
RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels instead
因此问题显然是bottleneck_output
具有 1024 个通道的形状,但decoder4
似乎无法识别它或诸如此类的东西。
我尝试过匹配尺寸和其他东西,比如对齐函数,但到目前为止没有任何效果。打印输出形状也没什么用。谢谢任何提示。