I have working code for style transfer on color images, but I can't figure out how to adapt it for grayscale images. I'd really appreciate some help.
If the values in all three RGB channels are the same, the image is effectively grayscale. Please accept this answer if it helps.
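For example, with PIL the replication is a one-liner. A minimal sketch (gray.png is a placeholder path, not from the original post):

from PIL import Image

gray = Image.open("gray.png").convert("L")    # single-channel grayscale
rgb = Image.merge("RGB", (gray, gray, gray))  # replicate into R, G and B
# equivalently: rgb = gray.convert("RGB")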
Hello, I still don't know what to change. This is the error my code produces.
Here is the code:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy
import os
import argparse
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
opts = argparse.Namespace()
opts.CONTENT_PATH = r"Change\content\timeBinary.png"  # raw strings: "\t" in a plain string would become a tab
opts.STYLE_PATH = r"Change\girl"
opts.SAVE_PATH = r"Change\result"
opts.use_cuda = torch.cuda.is_available()  # train on the GPU when one is available
opts.dtype = torch.cuda.FloatTensor if opts.use_cuda else torch.FloatTensor  # tensor type
opts.imsize = 128  # working image size
use_cuda = opts.use_cuda
dtype = opts.dtype
imsize = opts.imsize
loader = transforms.Compose([
    transforms.Resize([imsize, imsize]),  # resize imported image (transforms.Scale is deprecated)
    transforms.ToTensor()])  # transform it into a torch tensor
def image_loader(image_name):
    # convert("RGB") replicates a single grayscale channel into R, G and B,
    # so grayscale inputs pass through the 3-channel pipeline unchanged
    image = Image.open(image_name).convert("RGB")
    image = loader(image)
    # fake batch dimension required to fit network's input dimensions
    image = image.unsqueeze(0)
    return image
unloader = transforms.ToPILImage() # reconvert into PIL image
def imshow(tensor, title=None):
    image = tensor.clone().cpu()  # clone the tensor so we don't change it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
def image_unloader(tensor):
    image = tensor.clone().cpu()  # clone the tensor so we don't change it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    return image
class ContentLoss(nn.Module):
    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        # 'detach' the target content from the graph used to dynamically
        # compute the gradient: it is a stated value, not a variable.
        # Otherwise the forward method of the criterion throws an error.
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()
def forward(self, input):
self.loss = self.criterion(input * self.weight, self.target)
self.output = input
return self.output
def backward(self, retain_graph=True):
self.loss.backward(retain_graph=retain_graph)
return self.loss
class GramMatrix(nn.Module):
def forward(self, input):
a, b, c, d = input.size() # a=batch size(=1)
# b=number of feature maps
# (c,d)=dimensions of a f. map (N=c*d)
        features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL
        G = torch.mm(features, features.t())  # compute the gram product
        # 'normalize' the values of the gram matrix by dividing
        # by the number of elements in each feature map
        return G.div(a * b * c * d)
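# Illustrative shape check: for a 1x64x128x128 input, `features` is
# 64x16384 and G is 64x64; dividing by a*b*c*d keeps the style loss on a
# comparable scale across layers whose feature maps differ in size.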
class StyleLoss(nn.Module):
def __init__(self, target, weight):
super(StyleLoss, self).__init__()
self.target = target.detach() * weight
self.weight = weight
self.gram = GramMatrix()
self.criterion = nn.MSELoss()
def forward(self, input):
self.output = input.clone()
self.G = self.gram(input)
self.G.mul_(self.weight)
self.loss = self.criterion(self.G, self.target)
return self.output
def backward(self, retain_graph=True):
self.loss.backward(retain_graph=retain_graph)
return self.loss
cnn = models.vgg19(pretrained=True).features
if use_cuda:
cnn = cnn.cuda()
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
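# Content is matched at a single deeper layer (conv_4), while style is
# matched at several layers so that both fine and coarse texture
# statistics are captured.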
def get_style_model_and_losses(cnn, style_img, content_img,
style_weight=1000, content_weight=1,
content_layers=content_layers_default,
style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)
    # just in order to have iterable access to the lists of content/style
    # losses
content_losses = []
style_losses = []
model = nn.Sequential() # the new Sequential module network
gram = GramMatrix() # we need a gram module in order to compute style targets
# move these modules to the GPU if possible:
if use_cuda:
model = model.cuda()
gram = gram.cuda()
    # walk through the VGG-19 layers and, at the layers chosen above,
    # insert modules that compute the content and style losses
i = 1
for layer in list(cnn):
if isinstance(layer, nn.Conv2d):
name = "conv_" + str(i)
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).clone()
content_loss = ContentLoss(target, content_weight)
model.add_module("content_loss_" + str(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).clone()
                # flatten the extracted style features and build their Gram matrix
target_feature_gram = gram(target_feature)
style_loss = StyleLoss(target_feature_gram, style_weight)
model.add_module("style_loss_" + str(i), style_loss)
style_losses.append(style_loss)
        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            # use an out-of-place ReLU so the activations saved for the
            # loss modules are not overwritten in place
            model.add_module(name, nn.ReLU(inplace=False))
            i += 1
if isinstance(layer, nn.MaxPool2d):
name = "pool_" + str(i)
            model.add_module(name, layer)
return model, style_losses, content_losses
def get_input_param_optimizer(input_img):
# this line to show that input is a parameter that requires a gradient
input_param = nn.Parameter(input_img.data)
optimizer = optim.LBFGS([input_param])
return input_param, optimizer
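# Note: torch.optim.LBFGS re-evaluates the objective several times per
# step, which is why run_style_transfer below hands optimizer.step() a
# closure instead of calling backward() once per iteration.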
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
style_weight=1000, content_weight=1):
"""Run the style transfer."""
print('Building the style transfer model..')
model, style_losses, content_losses = get_style_model_and_losses(cnn,
style_img, content_img, style_weight,
content_weight)
    # freeze the network weights; only the input image is optimized
    for param in model.parameters():
        param.requires_grad = False
input_param, optimizer = get_input_param_optimizer(input_img)
print('Optimizing..')
run = [0]
while run[0] <= num_steps:
def closure():
# correct the values of updated input image
input_param.data.clamp_(0, 1)
optimizer.zero_grad()
model(input_param)
style_score = 0
content_score = 0
for sl in style_losses:
style_score += sl.backward()
for cl in content_losses:
content_score += cl.backward()
            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run[0]))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    style_score.item(), content_score.item()))
                print()
return style_score + content_score
optimizer.step(closure)
# a last correction...
input_param.data.clamp_(0, 1)
return model, style_losses, content_losses, input_param.data
CONTENT_PATH = opts.CONTENT_PATH
STYLE_PATH = opts.STYLE_PATH
SAVE_PATH = opts.SAVE_PATH
if not os.path.exists(SAVE_PATH):
    os.mkdir(SAVE_PATH)  # create the output directory
STEP = 300
img_lst = os.listdir(STYLE_PATH)
i = 1
for img_name in img_lst:
print("[%d/%d]" % (i, len(img_lst)))
i = i + 1
img_path = os.path.join(STYLE_PATH, img_name)
    style_img = image_loader(img_path).type(dtype)
    # load the content image
    content_img = image_loader(CONTENT_PATH).type(dtype)
input_img = content_img.clone()
model, style_losses, content_losses, output = run_style_transfer(cnn, content_img, style_img, input_img, STEP)
img = image_unloader(output)
img.save(os.path.join(SAVE_PATH, img_name))
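Following up on the accepted answer: since VGG-19 expects three input channels, the simplest grayscale adaptation keeps the network unchanged, replicates the gray channel on the way in, and collapses back to one channel on the way out. A minimal sketch (the helper names image_loader_gray and save_gray are mine, not part of the original script):

def image_loader_gray(image_name):
    # load as single-channel, then replicate into R, G and B so the
    # 3-channel VGG pipeline above can be reused unchanged
    image = Image.open(image_name).convert("L").convert("RGB")
    return loader(image).unsqueeze(0)

def save_gray(tensor, path):
    # the three output channels end up nearly identical, so collapsing
    # them with convert("L") recovers a grayscale result
    image = unloader(tensor.clone().cpu().squeeze(0))
    image.convert("L").save(path)

With these two helpers, the loop at the end would call image_loader_gray for the content image and save_gray in place of img.save.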