文章目录
- 一. (cpu,init)图像加载到CPU内存: 在 __init__ 函数中加载全部数据, 然后在 __getitem__ 中取图像
- 二. (cpu,getitem)图像在 __getitem__ 函数中载入到CPU
- 三(gpu,init)是将图像加载到GPU, 在init函数中
跑多光谱估计的代码,参考:https://github.com/caiyuanhao1998/MST-plus-plus
原代码dataset一次加载所有图像到cpu内存中
一. (cpu,init)图像加载到CPU内存: 在 __init__ 函数中加载全部数据, 然后在 __getitem__ 中取图像
这种方法比较常用,读取图像的效率也高,但是cpu内存要够
from torch.utils.data import Dataset
import numpy as np
import random
import cv2
import h5py
import torch
class TrainDataset(Dataset):
    """NTIRE2022 spectral-reconstruction training set, fully cached in CPU RAM.

    Every (hyperspectral cube, RGB image) pair is decoded once in __init__;
    __getitem__ only slices a crop and applies random augmentation, so data
    loading during training is fast — at the cost of CPU memory.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        """
        Args:
            data_root: dataset root containing Train_spectral/, Train_RGB/
                and split_txt/train_list.txt.
            crop_size: side length of the square patch returned per item.
            arg: if True, apply random 90-degree rotations and flips.
            bgr2rgb: convert the cv2-decoded BGR image to RGB.
            stride: sliding-window stride between neighbouring patches.
        """
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        h, w = 482, 512  # fixed image shape of the NTIRE2022 data
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            # strip() (rather than replace('\n', ...)) also handles a last
            # line without a trailing newline and Windows '\r' endings, and
            # skips blank lines.
            hyper_list = [line.strip() + '.mat' for line in fin if line.strip()]
        bgr_list = [name.replace('mat', 'jpg') for name in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
            if 'mat' not in hyper_path:
                continue
            # the context manager closes the file; no explicit close() needed
            with h5py.File(hyper_path, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [0, 2, 1])  # -> [31, 482, 512]
            bgr_path = bgr_data_path + bgr_list[i]
            assert hyper_list[i].split('.')[0] == bgr_list[i].split('.')[0], \
                'Hyper and RGB come from different scenes.'
            bgr = cv2.imread(bgr_path)
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            # per-image min-max normalisation to [0, 1]
            bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        """Apply rotTimes 90-degree rotations and optional V/H flips to a CHW array."""
        for _ in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        for _ in range(vFlip):
            img = img[:, :, ::-1].copy()
        for _ in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        """Return the idx-th (rgb, hyper) crop as contiguous float32 arrays."""
        stride = self.stride
        crop_size = self.crop_size
        # flat index -> (image, patch) -> (row, column) of the sliding window
        img_idx, patch_idx = idx // self.patch_per_img, idx % self.patch_per_img
        h_idx, w_idx = patch_idx // self.patch_per_line, patch_idx % self.patch_per_line
        bgr = self.bgrs[img_idx]
        hyper = self.hypers[img_idx]
        bgr = bgr[:, h_idx * stride:h_idx * stride + crop_size,
                  w_idx * stride:w_idx * stride + crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,
                      w_idx * stride:w_idx * stride + crop_size]
        # draw the random parameters unconditionally to keep the RNG stream
        # identical whether or not augmentation is enabled
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        # use the precomputed value instead of recomputing it
        return self.length
二. (cpu,getitem)图像在 __getitem__ 函数中载入到CPU
这种方法可以处理大数据集,比如所有图像占用内存大于电脑内存的时候,用这种方法
但是由于读取图像放在了get_item中,训练的时候加载数据会比较慢。
class TrainDataset_single(Dataset):
    """NTIRE2022 training set that stores only file paths in __init__ and
    decodes each image lazily inside __getitem__.

    Suitable when the whole dataset does not fit in CPU RAM, at the cost of
    slower per-item loading during training.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        """
        Args:
            data_root: dataset root containing Train_spectral/, Train_RGB/
                and split_txt/train_list.txt.
            crop_size: side length of the square patch returned per item.
            arg: if True, apply random 90-degree rotations and flips.
            bgr2rgb: convert the cv2-decoded BGR image to RGB.
            stride: sliding-window stride between neighbouring patches.
        """
        self.crop_size = crop_size
        self.hypers = []  # paths to the .mat spectral cubes
        self.bgrs = []    # paths to the .jpg RGB images
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h, w = 482, 512  # fixed image shape of the NTIRE2022 data
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            # strip() also handles a last line without '\n' and '\r' endings
            hyper_list = [line.strip() + '.mat' for line in fin if line.strip()]
        bgr_list = [name.replace('mat', 'jpg') for name in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            # only record the paths here; actual decoding happens in __getitem__
            self.hypers.append(hyper_data_path + hyper_list[i])
            self.bgrs.append(bgr_data_path + bgr_list[i])
            print(f'Ntire2022 scene {i} is indexed.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        """Apply rotTimes 90-degree rotations and optional V/H flips to a CHW array."""
        for _ in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        for _ in range(vFlip):
            img = img[:, :, ::-1].copy()
        for _ in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        """Read the idx-th patch from disk and return (rgb, hyper) float32 arrays."""
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx // self.patch_per_img, idx % self.patch_per_img
        h_idx, w_idx = patch_idx // self.patch_per_line, patch_idx % self.patch_per_line
        bgr_path = self.bgrs[img_idx]
        hyper_path = self.hypers[img_idx]
        with h5py.File(hyper_path, 'r') as mat:
            hyper = np.float32(np.array(mat['cube']))
        hyper = np.transpose(hyper, [0, 2, 1])  # -> [31, 482, 512]
        bgr = cv2.imread(bgr_path)
        if self.bgr2rgb:
            # channel-reversal slice is equivalent to cv2.cvtColor(BGR2RGB)
            bgr = bgr[..., ::-1]
        bgr = np.float32(bgr)
        bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())  # min-max to [0, 1]
        bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
        bgr = bgr[:, h_idx * stride:h_idx * stride + crop_size,
                  w_idx * stride:w_idx * stride + crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,
                      w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        return self.length
三(gpu,init)是将图像加载到GPU, 在init函数中
就是cpu内存不够不能使用方法一, 且我们不想速度太慢, 不能使用方法二。
如果GPU显存比较大的时候,或者有多个GPU的时候,可以在init函数中将图像读取到若干个GPU中。
比如下面,将450张读取到gpu0, 另外450张读取到gpu1
这样TrainDataset_gpu[i] 返回的就是在gpu上的数据
"""
数据在不同的gpu上,不能使用dataloader
"""
class TrainDataset_gpu(Dataset):
    """NTIRE2022 training set that caches every scene in GPU memory in __init__.

    Scenes are assigned to `devices` in contiguous chunks of
    `scenes_per_device` (default: scenes [0,450) -> cuda:0, [450,900) ->
    cuda:1, matching the original hard-coded split). Because items live on
    different GPUs, this dataset cannot be used with a standard DataLoader;
    index it manually and batch the results yourself.
    """

    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8,
                 devices=('cuda:0', 'cuda:1'), scenes_per_device=450):
        """
        Args:
            data_root: dataset root containing Train_spectral/, Train_RGB/
                and split_txt/train_list.txt.
            crop_size: side length of the square patch returned per item.
            arg: if True, apply random 90-degree rotations and flips.
            bgr2rgb: convert the cv2-decoded BGR image to RGB.
            stride: sliding-window stride between neighbouring patches.
            devices: device names the scenes are distributed over.
            scenes_per_device: number of consecutive scenes per device;
                overflow scenes go to the last device (the original version
                silently dropped scenes beyond 2 * 450).
        """
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h, w = 482, 512  # fixed image shape of the NTIRE2022 data
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            # strip() also handles a last line without '\n' and '\r' endings
            hyper_list = [line.strip() + '.mat' for line in fin if line.strip()]
        bgr_list = [name.replace('mat', 'jpg') for name in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
            bgr_path = bgr_data_path + bgr_list[i]
            if 'mat' not in hyper_path:
                continue
            with h5py.File(hyper_path, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [0, 2, 1])  # -> [31, 482, 512]
            assert hyper_list[i].split('.')[0] == bgr_list[i].split('.')[0], \
                'Hyper and RGB come from different scenes.'
            bgr = cv2.imread(bgr_path)
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())  # min-max to [0, 1]
            bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            # chunked assignment over the devices; clamp so overflow scenes
            # land on the last device instead of being silently dropped
            device = torch.device(devices[min(i // scenes_per_device, len(devices) - 1)])
            self.hypers.append(torch.from_numpy(hyper).to(device))
            self.bgrs.append(torch.from_numpy(bgr).to(device))
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, hyper, rotTimes, vFlip, hFlip):
        """Apply the same rotations/flips to both tensors (torch ops, stays on device)."""
        if rotTimes:
            img = torch.rot90(img, rotTimes, [1, 2])
            hyper = torch.rot90(hyper, rotTimes, [1, 2])
        if vFlip:
            img = torch.flip(img, dims=[1])
            hyper = torch.flip(hyper, dims=[1])
        if hFlip:
            img = torch.flip(img, dims=[2])
            hyper = torch.flip(hyper, dims=[2])
        return img, hyper

    def __getitem__(self, idx):
        """Return the idx-th (rgb, hyper) crop as tensors on their storage device."""
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx // self.patch_per_img, idx % self.patch_per_img
        h_idx, w_idx = patch_idx // self.patch_per_line, patch_idx % self.patch_per_line
        bgr = self.bgrs[img_idx]
        hyper = self.hypers[img_idx]
        bgr = bgr[:, h_idx * stride:h_idx * stride + crop_size,
                  w_idx * stride:w_idx * stride + crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,
                      w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr, hyper = self.arguement(bgr, hyper, rotTimes, vFlip, hFlip)
        return bgr, hyper

    def __len__(self):
        return self.length
但是读取到GPU之后,训练的时候 好像不能使用dataloader, 容易报错。
这个时候自己设计一个 批处理函数,和shuffle
# 1. Build the dataset (all scenes are cached on the GPUs inside __init__).
train_data = TrainDataset_gpu(data_root=opt.data_root, crop_size=opt.patch_size, bgr2rgb=True, arg=True, stride=opt.stride)
# 2. Shuffle the patch indices and truncate so the count is an exact
#    multiple of batch_size (hand-rolled replacement for a DataLoader,
#    which cannot batch tensors that live on different GPUs).
inddd = np.arange(len(train_data))
l = len(inddd) - (len(inddd) % opt.batch_size)
inddd2 = np.random.permutation(inddd)[:l]
inddd2 = inddd2.reshape(-1, opt.batch_size)  # shape: (batch num, batch size)
print(len(train_data), len(inddd) % opt.batch_size, inddd2.shape)
# 3. Iterate batch by batch.
for i in range(inddd2.shape[0]):
    t0 = time.time()
    # gather batch_size patches and concatenate them into one batch
    inddd3 = inddd2[i]
    # print('i, len, curlist:', i, len(inddd2), inddd3)
    images = []
    labels = []
    for j in inddd3:
        image, label = train_data[j]  # tensors already on some GPU
        image = image[None, ...]  # add a leading batch dimension
        label = label[None, ...]
        # print(i, j, image.shape, label.shape)
        # cv2.imwrite(f'{i:9d}_{j:4d}_image.png', (image[0].cpu().numpy().transpose(1,2,0)[...,[2,1,0]]*255).astype(np.uint8))
        # cv2.imwrite(f'{i:9d}_{j:4d}_label.png', (label[0].cpu().numpy().transpose(1,2,0)[...,[5,15,25]]*255).astype(np.uint8))
        # move to CPU first so tensors from different GPUs can be concatenated
        images.append(image.cpu())
        labels.append(label.cpu())
    images = torch.cat(images, 0)
    labels = torch.cat(labels, 0)
    # print(images.shape, labels.shape)
    labels = labels.cuda()
    images = images.cuda()