更改video swin transformer的损失函数

想要更换video swin transformer的损失函数应该需要更改哪里


import torch.nn as nn
import torch.optim as optim

# 定义自定义损失函数
def custom_loss_function(output, target):
    loss = nn.MSELoss()
    return loss(output, target)

# 定义模型、优化器和损失函数
model = VideoSwimTransformer()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_function = custom_loss_function

# 训练模型并计算损失
for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()