Attention U-Net built with PyTorch fails to converge

An organ-segmentation model built in PyTorch with Attention U-Net.

The task, roughly: segment organs in 500 CT slices of size 512x512 stored as .npy files. The masks are 500 labeled .npy files of shape 40x512x512 (that is, 40 organs are to be segmented, and each mask array consists of 0s and 1s).

For some reason training will not converge: the loss stays at 0.99, and adjusting the learning rate does not help. Is this caused by the code, and if so, how should I fix it? If the code is fine, what else could cause this behavior and how do I resolve it?

Here is the code:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# Dataset definition
class CustomDataset(Dataset):
    def __init__(self, data_dir, mode, transform=None):
        self.mode = mode
        self.transform = transform
        self.image_dir = os.path.join(data_dir, 'image')
        self.mask_dir = os.path.join(data_dir, 'mask')
        self.image_list = sorted(os.listdir(self.image_dir))
        self.mask_list = sorted(os.listdir(self.mask_dir))
        assert len(self.image_list) == len(self.mask_list), "Number of images and masks do not match!"

        if self.mode == 'train':
            self.image_list = self.image_list[:400]
            self.mask_list = self.mask_list[:400]
        elif self.mode == 'val':
            self.image_list = self.image_list[400:]
            self.mask_list = self.mask_list[400:]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])
        img = np.load(img_path)
        mask = np.load(mask_path)
        assert img.shape == (512, 512), f"Invalid image shape: {img.shape}"
        assert mask.shape == (40, 512, 512), f"Invalid mask shape: {mask.shape}"

        # Two normalization options (min-max scaling is used; the fixed-range alternative is commented out)
        img = (img - np.min(img)) / (np.max(img) - np.min(img))
        # img = img / 65535.0
        mask = mask / 255.0

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()
        return img.unsqueeze(0), mask


data_dir = '/home/siat/Desktop/orgseg/chest/train'  # data root directory
batch_size = 8

# Build the datasets and loaders
train_dataset = CustomDataset(data_dir, mode='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

val_dataset = CustomDataset(data_dir, mode='val')
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


# Network definition

class AttU_Net(nn.Module):
    def __init__(self, in_channels=1, out_channels=40):
        super(AttU_Net, self).__init__()

        # Downsampling stage 1
        self.conv1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling stage 2
        self.conv2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling stage 3
        self.conv3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling stage 4
        self.conv4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck convolutions
        self.conv5 = conv_block(512, 1024)

        # Upsampling stage 1
        self.upconv6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att6 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.conv6 = conv_block(1024, 512)

        # Upsampling stage 2
        self.upconv7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att7 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.conv7 = conv_block(512, 256)

        # Upsampling stage 3
        self.upconv8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att8 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.conv8 = conv_block(256, 128)

        # Upsampling stage 4
        self.upconv9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att9 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.conv9 = conv_block(128, 64)

        # Output head
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path
        x1 = self.conv1(x)
        x2 = self.conv2(self.pool1(x1))
        x3 = self.conv3(self.pool2(x2))
        x4 = self.conv4(self.pool3(x3))

        # Bottleneck
        x5 = self.conv5(self.pool4(x4))

        # Decoder path
        x6 = self.upconv6(x5)
        x6 = self.att6(x6, x4)
        x6 = self.conv6(torch.cat([x6, x4], dim=1))
        x7 = self.upconv7(x6)
        x7 = self.att7(x7, x3)
        x7 = self.conv7(torch.cat([x7, x3], dim=1))

        x8 = self.upconv8(x7)
        x8 = self.att8(x8, x2)
        x8 = self.conv8(torch.cat([x8, x2], dim=1))

        x9 = self.upconv9(x8)
        x9 = self.att9(x9, x1)
        x9 = self.conv9(torch.cat([x9, x1], dim=1))

        # Output
        out = self.conv10(x9)
        return out


class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # Global average pooling

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # Local average pooling
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g_avg = F.avg_pool2d(g, kernel_size=g.size()[2:])
        g_avg = self.W_g(g_avg)

        x_avg = F.avg_pool2d(x, kernel_size=x.size()[2:])
        x_avg = self.W_x(x_avg)

        psi = self.relu(g_avg + x_avg)
        psi = self.psi(psi)

        return x * psi


class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        return x


# Report the number of GPUs
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs for training")

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model; DataParallel replicates it across the available GPUs
model = AttU_Net().to(device)
model = nn.DataParallel(model)


# Softmax helper
def soft_max(a):
    # Flatten the last two (spatial) dimensions
    tensor_flat = a.view(-1, 512 * 512)

    # Apply softmax along the flattened spatial dimension
    softmax_flat = F.softmax(tensor_flat, dim=1)

    # Reshape back to the original shape
    softmax = softmax_flat.view(a.size(0), a.size(1), 512, 512)

    return softmax


# Soft Dice loss
def dice_loss(pred, target, smooth=1.0):
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum(1)
    return 1 - ((2. * intersection + smooth) / (m1.sum(1) + m2.sum(1) + smooth)).mean()


criterion = dice_loss

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Train the model
epochs = 20

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    print("-" * 10)

    # Training phase
    model.train()
    train_loss = 0.0

    for inputs, targets in tqdm(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        outputs_1 = soft_max(outputs)

        # Compute the loss
        loss = criterion(outputs_1, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)

        # Inspect gradient norms when needed
        # for name, param in model.named_parameters():
        #     if param.requires_grad:
        #         print(name, param.grad.data.norm(2))
    train_loss /= len(train_loader.dataset)
    print(f"Train Loss: {train_loss:.4f}")

    # Save model weights per epoch (optional)
    # torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pt")

    # Reload saved weights (optional)
    # if epoch < epochs - 1:
    # model.load_state_dict(torch.load(f"model_epoch_{epoch + 1}.pt"))

    # Validation phase
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs)

            outputs_2 = soft_max(outputs)
            # Compute the loss
            loss = criterion(outputs_2, targets)

            val_loss += loss.item() * inputs.size(0)

    val_loss /= len(val_loader.dataset)
    print(f"Val Loss: {val_loss:.4f}")

# Save the final model weights
torch.save(model.state_dict(), "model.pth")


This looks like a fairly complex model, and the specific cause is not obvious at a glance; from the information provided it is not yet certain why training fails to converge.

You can try the following steps to narrow the problem down and debug your code:

Check that data loading and preprocessing are correct. Pull a batch from train_loader or val_loader and visualize a few samples to confirm the images and masks are loaded and aligned as expected. One thing worth verifying: you describe the masks as 0/1 arrays, yet __getitem__ divides them by 255; if the stored values really are 0 and 1, this shrinks the targets to about 0.004 and the Dice loss will stay near 1 regardless of what the network predicts. See the sketch below.
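A minimal inspection sketch (assumes matplotlib is available; channel index 0 is chosen arbitrarily):

import matplotlib.pyplot as plt

# Grab one batch and sanity-check shapes and value ranges first.
imgs, masks = next(iter(train_loader))
print(imgs.shape, masks.shape)               # expect (8, 1, 512, 512) and (8, 40, 512, 512)
print(float(imgs.min()), float(imgs.max()))  # images should span [0, 1]
print(masks.unique())                        # 0/1 masks should show only 0 and 1

# Overlay one organ channel on the CT slice to verify alignment.
plt.imshow(imgs[0, 0].numpy(), cmap='gray')
plt.imshow(masks[0, 0].numpy(), alpha=0.4)
plt.show()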

Check that the model structure is correct. print(model) shows the layer hierarchy, and pushing a dummy tensor through the network confirms that the output shape matches the mask shape, as in the sketch below.
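A quick shape check, as a sketch (eval mode matters here: in train mode, BatchNorm needs more than one value per channel, and the global pooling inside Attention_block would violate that for a single-sample batch):

model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 1, 512, 512).to(device)
    print(model(dummy).shape)  # expect torch.Size([1, 40, 512, 512])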

Check the loss function and the activation that feeds it. This is the most suspicious spot: soft_max flattens each channel's 512x512 spatial positions and applies softmax over them, so every one of the 40 output maps is forced to sum to 1 over the whole image. The per-pixel values are then on the order of 1/262144, the Dice intersection with the masks stays near zero, and the loss sits at roughly 0.99, which matches what you observe. For multi-label 0/1 organ masks, a per-channel sigmoid is the conventional choice; see the sketch below.
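A minimal sketch of the per-channel sigmoid alternative; it replaces the soft_max call in both the training and validation loops, and dice_loss itself can stay unchanged:

# Each of the 40 organ maps gets an independent per-pixel probability,
# which matches 0/1 multi-label masks.
outputs_1 = torch.sigmoid(outputs)
loss = criterion(outputs_1, targets)

# If the 40 organs were mutually exclusive, softmax over the channel
# dimension (not the spatial one) would be the usual alternative:
# outputs_1 = F.softmax(outputs, dim=1)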

Check the optimizer and learning rate. Adam with lr=1e-3 is a reasonable default; if the loss plateaus, try 1e-4, and use the commented-out gradient-norm loop in the training code to confirm that gradients are actually flowing. A plateau-driven scheduler is sketched below.
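If you want the learning rate to drop automatically when the validation loss plateaus, ReduceLROnPlateau from torch.optim is a standard option (a sketch; patience=3 is an arbitrary choice):

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

# Call once per epoch, after val_loss has been computed:
# scheduler.step(val_loss)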

Try reducing model complexity. With only 400 training slices, a full-width U-Net (64 up to 1024 channels) may be harder to train than necessary and prone to overfitting; halving every channel count roughly quarters the parameter count. See the check below.
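To make model size concrete, count the trainable parameters before and after shrinking the widths (a quick check; the halving itself means editing the hard-coded channel counts in AttU_Net.__init__, e.g. 64 to 32 through 1024 to 512, with the decoder and attention widths scaled to match):

# Count trainable parameters to quantify model complexity.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M trainable parameters")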

Try a pretrained model. You can start from a U-Net with an ImageNet-pretrained encoder and fine-tune it for the organ-segmentation task; one option is sketched below.
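One sketch using the third-party segmentation_models_pytorch package (an assumption, not something your code already depends on; install with pip install segmentation-models-pytorch):

import segmentation_models_pytorch as smp

# U-Net with an ImageNet-pretrained ResNet-34 encoder; the encoder choice
# is arbitrary. in_channels=1 matches the CT slices, classes=40 the organs.
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights='imagenet',
    in_channels=1,
    classes=40,
).to(device)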

Try a different batch size. Very large batches can hurt generalization, though with batch_size = 8 that is unlikely to be the issue here; note also that DataParallel splits each batch across the GPUs, so each replica's BatchNorm layers see only part of the batch, which makes their statistics noisier.

Hopefully these hints help. If the problem persists, please share more detail, such as the loss curves, a few example predictions, and the exact optimizer and learning-rate settings you tried.