我的loss感觉没写错呢!
# ---- Discriminator update -------------------------------------------------
# Both discriminators are scored inside one autocast region; their losses are
# averaged and applied with a single optimizer step.
with torch.cuda.amp.autocast():
    # Horse side: real horses vs. horses generated from zebras.
    fake_horse = gen_H(zebra)
    D_H_real = disc_H(horse)
    # detach() keeps this backward pass from reaching gen_H's parameters.
    D_H_fake = disc_H(fake_horse.detach())
    # (Running means of D_H_real / D_H_fake could be accumulated here for
    # logging; the original left that disabled.)
    horse_real_target = torch.ones_like(D_H_real)
    horse_fake_target = torch.zeros_like(D_H_fake)
    D_H_real_loss = mse(D_H_real, horse_real_target)
    D_H_fake_loss = mse(D_H_fake, horse_fake_target)
    D_H_loss = D_H_real_loss + D_H_fake_loss

    # Zebra side: real zebras vs. zebras generated from horses.
    fake_zebra = gen_Z(horse)
    D_Z_real = disc_Z(zebra)
    D_Z_fake = disc_Z(fake_zebra.detach())
    zebra_real_target = torch.ones_like(D_Z_real)
    zebra_fake_target = torch.zeros_like(D_Z_fake)
    D_Z_real_loss = mse(D_Z_real, zebra_real_target)
    D_Z_fake_loss = mse(D_Z_fake, zebra_fake_target)
    D_Z_loss = D_Z_real_loss + D_Z_fake_loss

    # Average of the two discriminator losses.
    D_loss = (D_H_loss + D_Z_loss) / 2

# Backward and step happen outside autocast; the scaler scales the loss
# before backward and then drives the optimizer step.
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
d_scaler.step(opt_disc)
d_scaler.update()
# ---- Generator update -----------------------------------------------------
# Reuses fake_horse / fake_zebra from the discriminator step above, this time
# WITHOUT detach() so gradients flow back into gen_H / gen_Z.
with torch.cuda.amp.autocast():
    # Adversarial loss for both generators.
    # The generators want the discriminators to score their fakes as real,
    # so the target here is ones -- the label the discriminator step uses
    # for real images. Using zeros (the discriminator's "fake" label) would
    # train the generators to be detected instead of to fool the critics.
    D_H_fake = disc_H(fake_horse)
    D_Z_fake = disc_Z(fake_zebra)
    loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
    loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

    # Cycle loss: translating to the other domain and back should
    # reproduce the input image.
    cycle_zebra = gen_Z(fake_horse)
    cycle_horse = gen_H(fake_zebra)
    cycle_zebra_loss = l1(zebra, cycle_zebra)
    cycle_horse_loss = l1(horse, cycle_horse)

    # Identity loss: a generator fed an image already in its output domain
    # should leave it unchanged. (These two forward passes can be removed
    # for efficiency if config.LAMBDA_IDENTITY is 0.)
    identity_zebra = gen_Z(zebra)
    identity_horse = gen_H(horse)
    identity_zebra_loss = l1(zebra, identity_zebra)
    identity_horse_loss = l1(horse, identity_horse)

    # Add all terms together, weighting cycle and identity by their lambdas.
    G_loss = (
        loss_G_Z
        + loss_G_H
        + cycle_zebra_loss * config.LAMBDA_CYCLE
        + cycle_horse_loss * config.LAMBDA_CYCLE
        + identity_horse_loss * config.LAMBDA_IDENTITY
        + identity_zebra_loss * config.LAMBDA_IDENTITY
    )

# Backward and step happen outside autocast; the scaler scales the loss
# before backward and then drives the optimizer step.
opt_gen.zero_grad()
g_scaler.scale(G_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
no pain no gan