class ImageFolder(Dataset):
    """Paired left/right image dataset with an extra homography patch sample.

    Expects ``<root>/<split>/left`` and ``<root>/<split>/right`` directories;
    images are paired by position in the sorted listings and must share the
    same base file name.

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to both RGB crops (the original
            note says only ``ToTensor`` is used here).
        patch_size: ``(height, width)`` of the aligned random crop taken from
            both images.
        split: Sub-directory of ``root`` to load, e.g. ``"train"``.
        need_file_name: If True, the left image's base name is included in
            each returned sample.
        homopic_size: Side length the crops are resized to before cutting the
            homography patch (default 256, the previously hard-coded value).
        homopatch_size: Side length of the square homography patch
            (default 128).
        rho: Margin kept between the homography patch and the resized image
            border when there is room for it (default 45).
    """

    def __init__(self, root, transform=None, patch_size=(256, 256), split='train',
                 need_file_name=False, *, homopic_size=256, homopatch_size=128, rho=45):
        splitdir = Path(root) / split  # equivalent to os.path.join(root, split)
        if not splitdir.is_dir():
            # Report the directory that was actually checked, not just ``root``.
            raise RuntimeError(f'Invalid directory "{splitdir}"')
        self.left_list = sorted(glob.glob(os.path.join(splitdir / "left", "*")))
        self.right_list = sorted(glob.glob(os.path.join(splitdir / "right", "*")))
        self.patch_size = patch_size
        self.transform = transform
        self.need_file_name = need_file_name

        # Homography-branch settings (cropped separately from the main patch).
        self.homopic_size = homopic_size
        self.homopatch_size = homopatch_size
        self.rho = rho
        # Resizing is done with cv2.resize in __getitem__, so only tensor
        # conversion and normalization happen here.
        self.homotransforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=MEAN, std=STD),
            ]
        )

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` with OpenCV and convert it from BGR to RGB order."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _random_pair_crop(self, img1, img2):
        """Cut the same random ``patch_size`` window out of both images."""
        ph, pw = self.patch_size
        h, w = img1.shape[:2]
        if h < ph or w < pw:
            raise ValueError(f"image of size {h}x{w} is smaller than patch_size {self.patch_size}")
        # random.randint is inclusive on both ends, so ``h - ph`` / ``w - pw``
        # is the last valid start. Each axis is sampled independently, and no
        # random draw is made along an axis that already matches the patch.
        top = random.randint(0, h - ph) if h > ph else 0
        left = random.randint(0, w - pw) if w > pw else 0
        return (
            img1[top:top + ph, left:left + pw],
            img2[top:top + ph, left:left + pw],
        )

    def _homography_sample(self, img1, img2):
        """Build the grayscale homography patch pair and its corner coordinates.

        Both crops are resized to ``homopic_size`` squared, converted with
        ``homotransforms``, averaged over channels into one channel, and cut at
        the same random ``homopatch_size`` square whose corners are returned as
        a (4, 2) float32 tensor of ``(x, y)`` points.
        """
        size = (self.homopic_size, self.homopic_size)
        homo_img1 = self.homotransforms(cv2.resize(img1, size))
        homo_img2 = self.homotransforms(cv2.resize(img2, size))
        # Grayscale via the channel mean, keeping a leading channel dim of 1.
        homo_img1 = torch.mean(homo_img1, dim=0, keepdim=True)
        homo_img2 = torch.mean(homo_img2, dim=0, keepdim=True)

        p = self.homopatch_size
        max_xy = self.homopic_size - self.rho - p
        # Keep ``rho`` pixels free on every side of the patch when the image is
        # large enough (presumably for later corner perturbation - confirm).
        if max_xy >= self.rho:
            x = random.randint(self.rho, max_xy)
            y = random.randint(self.rho, max_xy)
        else:
            x = y = 0
        corners = torch.tensor(
            [
                [x, y],
                [x + p, y],
                [x + p, y + p],
                [x, y + p],
            ],
            dtype=torch.float32,
        )
        return (
            homo_img1[:, y:y + p, x:x + p],
            homo_img2[:, y:y + p, x:x + p],
            corners,
        )

    def __getitem__(self, index):
        """Return one sample.

        Args:
            index (int): Position in the sorted left/right file lists.

        Returns:
            If ``get_H`` yields ``None`` as its first element: ``(img1, img2)``
            (passed through ``transform`` when set). Otherwise
            ``(img1, img2, H, [file_name,] homo_img1, homo_img2, corners)``,
            where ``file_name`` appears only when ``need_file_name`` is True.

        Raises:
            ValueError: If the left/right file names at ``index`` differ, or
                the image is smaller than ``patch_size``.
            FileNotFoundError: If OpenCV cannot read one of the images.
        """
        left_path = self.left_list[index]
        right_path = self.right_list[index]
        file_name = os.path.basename(left_path)
        if file_name != os.path.basename(right_path):
            raise ValueError(f"cannot compare pictures: {left_path} vs {right_path}")

        img1 = self._read_rgb(left_path)
        img2 = self._read_rgb(right_path)
        img1, img2 = self._random_pair_crop(img1, img2)

        # Traditional homography estimate (the original note says it can be
        # ignored); only its first element is used below.
        H_list = get_H(img1, img2)
        homo_img1, homo_img2, corners = self._homography_sample(img1, img2)

        # ``is None`` rather than ``== None``: if get_H returns a NumPy matrix,
        # ``==`` compares element-wise and the result cannot be used in ``if``.
        if H_list[0] is None:
            print(left_path)
            print(right_path)
            # NOTE(review): this 2-tuple differs from the 6/7-tuple returned
            # below; default DataLoader collation fails if both shapes end up
            # in one batch - confirm this is intended.
            if self.transform:
                return self.transform(img1), self.transform(img2)
            return img1, img2

        if self.transform:
            img1, img2 = self.transform(img1), self.transform(img2)
        sample = [img1, img2, H_list[0]]
        if self.need_file_name:
            sample.append(file_name)
        sample.extend((homo_img1, homo_img2, corners))
        return tuple(sample)

    def __len__(self):
        """Return the number of images found in the left directory."""
        return len(self.left_list)
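# 注：应该写 args.patch_size。parser.add_argument 中定义的是 --patch-size，argparse 会自动把选项名中的短横杠
# 转换为下划线作为属性名（dest），所以要用 args.patch_size 访问；args.patch-size 会被 Python 解析为
# args.patch 减去 size。即使把定义改为 --patch_size，属性名也同样是 args.patch_size。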