# Imports inferred from usage later in the excerpt (the original post omits them).
import datetime
import glob
import os
import re
import time

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader


def cov(m):
    # Assumed helper: the post calls cov() without showing its definition.
    # This version returns the covariance of the columns of m,
    # shape [n, k] -> [k, k], which matches how it is used below.
    m = m - m.mean(dim=0, keepdim=True)
    return m.t() @ m / (m.shape[0] - 1)


class ComplexBN(torch.nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
        super(ComplexBN, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.eps = eps
        self.num_features = num_features
        # self.batchNorm2dF = torch.nn.BatchNorm2d(num_features,
        #                                          affine=affine).to(self.device)

    def forward(self, x):  # shape of x: [batch, 2, channel, axis1, axis2]
        # split dim=1 into its two parts -> real and imag,
        # each of shape [batch, channel, axis1, axis2]
        real = x[:, 0]
        imag = x[:, 1]
        print('4', x.shape)
        print('5', real.shape)
        print('6', imag.shape)
        realVec = torch.flatten(real)
        print('11', realVec.shape)
        imagVec = torch.flatten(imag)
        re_im_stack = torch.stack((realVec, imagVec), dim=1)
        covMat = cov(re_im_stack)
        e, v = torch.linalg.eigh(covMat)
        # C^{-1/2}: inverse square root of the 2x2 covariance (whitening matrix)
        covMat_sq2 = torch.mm(torch.mm(v, torch.diag(torch.pow(e, -0.5))),
                              v.t())
        data = torch.stack((realVec - real.mean(), imagVec - imag.mean()),
                           dim=1).t()
        whitenData = torch.mm(covMat_sq2, data)
        real_data = whitenData[0, :].reshape(real.shape[0], real.shape[1],
                                             real.shape[2], real.shape[3])
        imag_data = whitenData[1, :].reshape(real.shape[0], real.shape[1],
                                             real.shape[2], real.shape[3])
        output = torch.stack((real_data, imag_data), dim=1)
        print('12', output.shape)
        return output
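
def _whitening_sanity_check():
    # Standalone sketch (mine, not from the original post): verifies the
    # ZCA-style step that ComplexBN.forward performs. After multiplying the
    # centred (real, imag) rows by C^{-1/2}, their 2x2 covariance should be
    # approximately the identity.
    torch.manual_seed(0)
    n = 10000
    real = torch.randn(n)
    imag = 0.8 * real + 0.2 * torch.randn(n)      # correlated components
    data = torch.stack((real - real.mean(), imag - imag.mean()))  # [2, n]
    covMat = data @ data.t() / (n - 1)
    e, v = torch.linalg.eigh(covMat)
    cov_inv_sqrt = v @ torch.diag(e.pow(-0.5)) @ v.t()
    white = cov_inv_sqrt @ data
    print(white @ white.t() / (n - 1))            # ~identity matrix
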
class ComplexConv2D(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(ComplexConv2D, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.padding = padding
        # Model components: one conv for the real weights, one for the
        # imaginary weights
        self.conv_re = torch.nn.Conv2d(in_channels,
                                       out_channels,
                                       kernel_size,
                                       stride=stride,
                                       padding=padding,
                                       dilation=dilation,
                                       groups=groups,
                                       bias=bias).to(self.device)
        self.conv_im = torch.nn.Conv2d(in_channels,
                                       out_channels,
                                       kernel_size,
                                       stride=stride,
                                       padding=padding,
                                       dilation=dilation,
                                       groups=groups,
                                       bias=bias).to(self.device)
        self.weight1 = self.conv_re.weight
        self.weight2 = self.conv_im.weight
        self.bias1 = self.conv_re.bias
        self.bias2 = self.conv_im.bias

    def forward(self, x):
        paddingF = torch.nn.ZeroPad2d(1)
        print('1', x.shape)
        r = paddingF(x[:, 0])  # NCHW
        i = paddingF(x[:, 1])
        print('2', r.shape)
        # New 20191102: overwrite the zero border with wrapped-around
        # rows/columns
        r[:, :, 0, :], i[:, :, 0, :] = r[:, :, -2, :], i[:, :, -2, :]
        r[:, :, -1, :], i[:, :, -1, :] = r[:, :, 1, :], i[:, :, 1, :]
        r[:, :, :, 0], i[:, :, :, 0] = r[:, :, :, 2], i[:, :, :, 2]
        r[:, :, :, -1], i[:, :, :, -1] = r[:, :, :, 1], i[:, :, :, 1]
        print('1234', r.shape)
        print('1235', i.shape)
        # NEW END
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        real = self.conv_re(r) - self.conv_im(i)
        print('3', real.shape)
        imaginary = self.conv_re(i) + self.conv_im(r)
        print('9', imaginary.shape)
        # stack real and imaginary parts together at dim=1
        output = torch.stack((real, imaginary), dim=1)
        print('10', output.shape)
        return output
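
def _complex_product_sanity_check():
    # Standalone sketch (mine, not from the original post): with a 1x1
    # kernel, complex convolution degenerates to per-pixel complex
    # multiplication, so the (ac - bd, ad + bc) combination used in
    # ComplexConv2D.forward can be checked against torch's native
    # complex arithmetic.
    torch.manual_seed(0)
    xr, xi = torch.randn(8, 8), torch.randn(8, 8)   # input real/imag parts
    wr, wi = torch.randn(()), torch.randn(())       # a 1x1 "kernel"
    real = wr * xr - wi * xi
    imag = wr * xi + wi * xr
    ref = torch.complex(xr, xi) * torch.complex(wr, wi)
    print(torch.allclose(torch.complex(real, imag), ref))  # True
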
class ComplexReLU(torch.nn.Module):
    def __init__(self, inplace=False):
        super(ComplexReLU, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.relu_re = torch.nn.ReLU(inplace=inplace).to(self.device)
        self.relu_im = torch.nn.ReLU(inplace=inplace).to(self.device)

    def forward(self, x):
        print('122', x.shape)
        output = torch.stack(
            (self.relu_re(x[:, 0]), self.relu_im(x[:, 1])),
            dim=1).to(self.device)
        print('123', output.shape)
        return output
class ComplexDnCNN(torch.nn.Module):
    def __init__(self,
                 depth=17,
                 n_channels=64,
                 image_channels=1,
                 use_bnorm=True,
                 kernel_size=3):
        super(ComplexDnCNN, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # kernel_size = 3
        padding = 0
        layers = []
        # 1. Conv2d + ReLU
        layers.append(
            ComplexConv2D(in_channels=image_channels,
                          out_channels=n_channels,
                          kernel_size=kernel_size,
                          padding=padding,
                          bias=True))
        layers.append(ComplexReLU(inplace=False))
        # 2. 15 * (Conv2d + BN + ReLU)
        for _ in range(depth - 2):
            layers.append(
                ComplexConv2D(in_channels=n_channels,
                              out_channels=n_channels,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=False))
            # layers.append(torch.nn.BatchNorm2d(
            #     n_channels, eps=0.0001, momentum=0.95).to(device=self.device))
            layers.append(ComplexBN(n_channels, eps=0.0001, momentum=0.95))
            layers.append(ComplexReLU(inplace=False))
        # 3. Conv2d
        layers.append(
            ComplexConv2D(in_channels=n_channels,
                          out_channels=image_channels,
                          kernel_size=kernel_size,
                          padding=padding,
                          bias=False))
        self.dncnn = torch.nn.Sequential(*layers)

    def forward(self, x):
        y = x
        print('20', x.shape)
        out = self.dncnn(x)  # residual learning: the network predicts the noise
        print('21', out.shape)
        return y - out
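
def _shape_sanity_check():
    # Standalone sketch (mine, not from the original post): with the layout
    # the forward comments describe -- [batch, 2, channel, H, W], real/imag
    # pair at dim 1 -- the network preserves the input shape, so the final
    # y - out is well-defined. Depth and width are reduced only to keep the
    # check cheap.
    model = ComplexDnCNN(depth=3, n_channels=8, image_channels=1)
    x = torch.randn(4, 2, 1, 16, 16).to(model.device)
    print(model(x).shape)   # expected: torch.Size([4, 2, 1, 16, 16])
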
def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir, 'model_*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*model_(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch


def log(*args, **kwargs):
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args,
          **kwargs)
if __name__ == '__main__':
    # save_dir, train_data, n_epoch, batch_size, cuda, args and
    # MyDenoisingDataset are defined earlier in the full script
    # (omitted from this excerpt).
    print('>>> Building Model')
    model = ComplexDnCNN().cuda()
    # uncomment to use DataParallel (unstable)
    # device_ids = [0, 1]
    # model = nn.DataParallel(model, device_ids=device_ids).cuda()
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        model.load_state_dict(
            torch.load(os.path.join(save_dir,
                                    'model_%03d.pth' % initial_epoch)))
    print(">>> Building Model Finished")
    model.train()  # enable BN and Dropout
    criterion = nn.MSELoss(reduction='sum').cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[20, 40, 60, 80, 100, 120],
                            gamma=0.5)  # learning-rate schedule
    print("Loading Data")
    # the .mat files should be generated by yourself
    train_est = train_data + '/' + \
        '5dB' + '/' + 'trainingChannel15' + '.mat'
    train_true = train_data + '/' + \
        '5dB' + '/' + 'trueTrainingChannel15' + '.mat'
    train_est_mat = sio.loadmat(train_est)
    # print(train_est_mat.keys())
    x_train = train_est_mat['trainingChannel']
    print('111', x_train.shape)
    x_train = np.transpose(x_train, [3, 2, 1, 0])
    print('>>> training Set setup complete')
    # ground truth
    train_true_mat = sio.loadmat(train_true)
    y_train = train_true_mat['trueTrainingChannel']
    print('222', y_train.shape)
    y_train = np.transpose(y_train, [3, 2, 1, 0])
    print('>>> groundTruth Set setup complete')
    x_train = torch.from_numpy(x_train).float().reshape(
        [x_train.shape[0], x_train.shape[1], 1,
         x_train.shape[2], x_train.shape[3]])
    # x_train = x_train[0:2000, :]
    print(x_train.shape)
    y_train = torch.from_numpy(y_train).float().reshape(
        [y_train.shape[0], y_train.shape[1], 1,
         y_train.shape[2], y_train.shape[3]])
    print(y_train.shape)
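    # NOTE: the forward comments above expect x of shape
    # [batch, 2, channel, H, W] with the real/imag pair at dim 1, but the
    # reshape here yields torch.Size([64, 256, 1, 2, 1]), i.e. dim 1 is 256.
    # One plausible matching layout, assuming the 256 axis is spatial and
    # the size-2 axis is the real/imag pair (hypothetical, untested):
    # x_train = x_train.reshape(64, 256, 2, 1).permute(0, 2, 1, 3) \
    #                  .reshape(64, 2, 1, 256, 1)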
    for epoch in range(initial_epoch, n_epoch):
        DDataset = MyDenoisingDataset(y_train, x_train)
        DLoader = DataLoader(dataset=DDataset,
                             num_workers=0,
                             drop_last=True,
                             batch_size=batch_size,
                             shuffle=True)
        epoch_loss = 0
        start_time = time.time()
        for n_count, batch_yx in enumerate(DLoader):
            optimizer.zero_grad()
            if cuda:
                batch_x, batch_y = batch_yx[1].cuda(), batch_yx[0].cuda()
            loss = criterion(model(batch_y), batch_x)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step(epoch)  # step to the learning rate for this epoch
            if n_count % 10 == 0:
                print('%4d %4d / %4d loss = %2.4f\t ' %
                      (epoch + 1, n_count, x_train.size(0) // batch_size,
                       loss.item() / batch_size))
        elapsed_time = time.time() - start_time
        log('epoch = %4d , loss = %4.4f , time = %4.2f s' %
            (epoch + 1, epoch_loss / n_count, elapsed_time))
        torch.save(model.state_dict(),
                   os.path.join(save_dir, 'model_%03d.pth' % (epoch + 1)))
        # torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))
Below is the output from the modules and the problem that occurs:
>>> Building Model
>>> Building Model Finished
Loading Data
111 (1, 2, 256, 64)
>>> training Set setup complete
222 (1, 2, 256, 64)
>>> groundTruth Set setup complete
torch.Size([64, 256, 1, 2, 1])
torch.Size([64, 256, 1, 2, 1])
20 torch.Size([8, 256, 1, 2, 1])
1 torch.Size([8, 256, 1, 2, 1])
2 torch.Size([8, 1, 4, 3])
1234 torch.Size([8, 1, 4, 3])
1235 torch.Size([8, 1, 4, 3])
3 torch.Size([8, 64, 2, 1])
9 torch.Size([8, 64, 2, 1])
10 torch.Size([8, 2, 64, 2, 1])
122 torch.Size([8, 2, 64, 2, 1])
123 torch.Size([8, 2, 64, 2, 1])
1 torch.Size([8, 2, 64, 2, 1])
2 torch.Size([8, 64, 4, 3])
1234 torch.Size([8, 64, 4, 3])
1235 torch.Size([8, 64, 4, 3])
3 torch.Size([8, 64, 2, 1])
9 torch.Size([8, 64, 2, 1])
10 torch.Size([8, 2, 64, 2, 1])
4 torch.Size([8, 2, 64, 2, 1])
5 torch.Size([8, 64, 2, 1])
6 torch.Size([8, 64, 2, 1])
11 torch.Size([1024])
12 torch.Size([8, 2, 64, 2, 1])
122 torch.Size([8, 2, 64, 2, 1])
123 torch.Size([8, 2, 64, 2, 1])
(the 14-line block above -- shape prints 1/2/1234/1235/3/9/10 from ComplexConv2D, 4/5/6/11/12 from ComplexBN, 122/123 from ComplexReLU -- repeats with identical sizes for each of the remaining 14 middle Conv+BN+ReLU layers)
1 torch.Size([8, 2, 64, 2, 1])
2 torch.Size([8, 64, 4, 3])
1234 torch.Size([8, 64, 4, 3])
1235 torch.Size([8, 64, 4, 3])
3 torch.Size([8, 1, 2, 1])
9 torch.Size([8, 1, 2, 1])
10 torch.Size([8, 2, 1, 2, 1])
21 torch.Size([8, 2, 1, 2, 1])
Traceback (most recent call last):
File "D:/Doucments/Desktop/去噪/complex-DnCNN-master/train_cDnCNN.py", line 349, in
loss = criterion(model(batch_y), batch_x)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "D:/Doucments/Desktop/去噪/complex-DnCNN-master/train_cDnCNN.py", line 262, in forward
return y - out
RuntimeError: The size of tensor a (256) must match the size of tensor b (2) at non-singleton dimension 1