关于pytorch的加载数据,cpu init, cpu getitem, gpu init

文章目录

    • 一. (cpu,init)图像加载到CPU内存,是在 __init__中函数中全部数据, 然后在item中取图像
    • 二.(cpu,get_item)是图像在 get_item函数中,载入图像到CPU
    • 三(gpu,init)是将图像加载到GPU, 在init函数中

跑多光谱估计的代码,参考:https://github.com/caiyuanhao1998/MST-plus-plus
原代码dataset一次加载所有图像到cpu内存中

一. (cpu,init)图像加载到CPU内存,是在 __init__中函数中全部数据, 然后在item中取图像

这种方法比较常用,读取图像的效率也高,但是cpu内存要够

from torch.utils.data import Dataset
import numpy as np
import random
import cv2
import h5py
import torch
class TrainDataset(Dataset):
    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        h,w = 482,512  # img shape
        self.stride = stride
        self.patch_per_line = (w-crop_size)//stride+1
        self.patch_per_colum = (h-crop_size)//stride+1
        self.patch_per_img = self.patch_per_line*self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            hyper_list = [line.replace('\n','.mat') for line in fin]
            bgr_list = [line.replace('mat','jpg') for line in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        
        # hyper_list = hyper_list[:300]
        # bgr_list = bgr_list[:300]
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
            if 'mat' not in hyper_path:
                continue
            with h5py.File(hyper_path, 'r') as mat:
                hyper =np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [0, 2, 1])
            bgr_path = bgr_data_path + bgr_list[i]
            assert hyper_list[i].split('.')[0] ==bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'
            bgr = cv2.imread(bgr_path)
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            bgr = (bgr-bgr.min())/(bgr.max()-bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            mat.close()
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random vertical Flip
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random horizontal Flip
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx//self.patch_per_img, idx%self.patch_per_img
        h_idx, w_idx = patch_idx//self.patch_per_line, patch_idx%self.patch_per_line
        bgr = self.bgrs[img_idx]
        hyper = self.hypers[img_idx]
        bgr = bgr[:,h_idx*stride:h_idx*stride+crop_size, w_idx*stride:w_idx*stride+crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        return self.patch_per_img*self.img_num

二.(cpu,get_item)是图像在 get_item函数中,载入图像到CPU

这种方法可以处理大数据集,比如所有图像占用内存大于电脑内存的时候,用这种方法
但是由于读取图像放在了get_item中,训练的时候加载数据会比较慢。

class TrainDataset_single(Dataset):
    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h,w = 482,512  # img shape
        self.stride = stride
        self.patch_per_line = (w-crop_size)//stride+1
        self.patch_per_colum = (h-crop_size)//stride+1
        self.patch_per_img = self.patch_per_line*self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            hyper_list = [line.replace('\n','.mat') for line in fin]
            bgr_list = [line.replace('mat','jpg') for line in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        
        # hyper_list = hyper_list[:300]
        # bgr_list = bgr_list[:300]
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
          
            bgr_path = bgr_data_path + bgr_list[i]
            
            # if 'mat' not in hyper_path:
            #     continue
            # with h5py.File(hyper_path, 'r') as mat:
            #     hyper =np.float32(np.array(mat['cube']))
            # hyper = np.transpose(hyper, [0, 2, 1])
            # assert hyper_list[i].split('.')[0] ==bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'
            # bgr = cv2.imread(bgr_path)
            # if bgr2rgb:
            #     bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            # bgr = np.float32(bgr)
            # bgr = (bgr-bgr.min())/(bgr.max()-bgr.min())
            # bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            self.hypers.append(hyper_path)
            self.bgrs.append(bgr_path)
            # mat.close()
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random vertical Flip
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random horizontal Flip
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx//self.patch_per_img, idx%self.patch_per_img
        h_idx, w_idx = patch_idx//self.patch_per_line, patch_idx%self.patch_per_line
        bgr_path = self.bgrs[img_idx]
        hyper_path = self.hypers[img_idx]
        
        # if 'mat' not in hyper_path:
        #     continue
        with h5py.File(hyper_path, 'r') as mat:
            hyper =np.float32(np.array(mat['cube']))
        hyper = np.transpose(hyper, [0, 2, 1])
        # assert hyper_list[i].split('.')[0] ==bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'
        bgr = cv2.imread(bgr_path)
        if self.bgr2rgb:
            bgr = bgr[..., ::-1] #cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        bgr = np.float32(bgr)
        bgr = (bgr-bgr.min())/(bgr.max()-bgr.min())
        bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            
        bgr = bgr[:,h_idx*stride:h_idx*stride+crop_size, w_idx*stride:w_idx*stride+crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        return self.patch_per_img*self.img_num

三(gpu,init)是将图像加载到GPU, 在init函数中

就是cpu内存不够不能使用方法一,且我们不像速度太慢不能使用方法二。
如果GPU显存比较大的时候,或者有多个GPU的时候,可以在init函数中将图像读取到若干个GPU中。

比如下面,将450张读取到gpu0, 另外450张读取到gpu1
这样TrainDataset_gpu[i] 返回的就是在gpu上的数据

"""
数据在不同的gpu上,不能使用dataloader
"""    
class TrainDataset_gpu(Dataset):
    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h,w = 482,512  # img shape
        self.stride = stride
        self.patch_per_line = (w-crop_size)//stride+1
        self.patch_per_colum = (h-crop_size)//stride+1
        self.patch_per_img = self.patch_per_line*self.patch_per_colum

        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'

        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            hyper_list = [line.replace('\n','.mat') for line in fin]
            bgr_list = [line.replace('mat','jpg') for line in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        
        # hyper_list = hyper_list[:100]
        # bgr_list = bgr_list[:100]
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
          
            bgr_path = bgr_data_path + bgr_list[i]
            
            if 'mat' not in hyper_path:
                continue
            with h5py.File(hyper_path, 'r') as mat:
                hyper =np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [0, 2, 1])
            assert hyper_list[i].split('.')[0] ==bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'
            bgr = cv2.imread(bgr_path)
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            bgr = (bgr-bgr.min())/(bgr.max()-bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            
            if i < 450:
                device = torch.device('cuda:0')
                self.hypers.append(torch.from_numpy(hyper).to(device))
                self.bgrs.append(torch.from_numpy(bgr).to(device))
            elif i<900:
                device = torch.device('cuda:1')
                self.hypers.append(torch.from_numpy(hyper).to(device))
                self.bgrs.append(torch.from_numpy(bgr).to(device))
            # mat.close()
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, hyper, rotTimes, vFlip, hFlip):
         # Random rotation
        if rotTimes:
            img = torch.rot90(img, rotTimes, [1, 2])
            hyper = torch.rot90(hyper, rotTimes, [1, 2])
        # Random vertical Flip
        if vFlip:
            #img = img[:, :, ::-1]
            img = torch.flip(img, dims=[1])
            hyper = torch.flip(hyper, dims=[1])
        # Random horizontal Flip
        if hFlip:
            #img = img[:, ::-1, :]
            img = torch.flip(img, dims=[2])
            hyper = torch.flip(hyper, dims=[2])
        return img, hyper

    def __getitem__(self, idx):
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx//self.patch_per_img, idx%self.patch_per_img
        h_idx, w_idx = patch_idx//self.patch_per_line, patch_idx%self.patch_per_line
        bgr = self.bgrs[img_idx]
        hyper = self.hypers[img_idx]
        
        bgr = bgr[:,h_idx*stride:h_idx*stride+crop_size, w_idx*stride:w_idx*stride+crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr, hyper = self.arguement(bgr, hyper, rotTimes, vFlip, hFlip)
        
        return bgr, hyper # np.ascontiguousarray(bgr.cpu().numpy()), np.ascontiguousarray(hyper.cpu().numpy()) 

    def __len__(self):
        return self.patch_per_img*self.img_num

但是读取到GPU之后,训练的时候 好像不能使用dataloader, 容易报错。

这个时候自己设计一个 批处理函数,和shuffle

# 1.加载数据集
train_data = TrainDataset_gpu(data_root=opt.data_root, crop_size=opt.patch_size, bgr2rgb=True, arg=True, stride=opt.stride)

# 2. 获取数据集的长度, 使是batch_size的倍数, 打乱顺序
inddd = np.arange(len(train_data))
l =len(inddd) -  (len(inddd)%opt.batch_size) 
inddd2 = np.random.permutation(inddd)[:l]
inddd2 = inddd2.reshape(-1, opt.batch_size) #batch num, batch size
print(len(train_data), len(inddd)%opt.batch_size, inddd2.shape)

# 3. 读取每一个batch的图像
for i in range(inddd2.shape[0]):
    t0 = time.time()
    # 检索batch size个图像拼接为一个batch
    inddd3 = inddd2[i]
    #print('i, len, curlist:',i, len(inddd2), inddd3)
    images = []
    labels = []
    for j in inddd3:
        image, label = train_data[j]
        image = image[None, ...]
        label = label[None, ...]
        # print(i, j, image.shape, label.shape)
        # cv2.imwrite(f'{i:9d}_{j:4d}_image.png', (image[0].cpu().numpy().transpose(1,2,0)[...,[2,1,0]]*255).astype(np.uint8))
        # cv2.imwrite(f'{i:9d}_{j:4d}_label.png', (label[0].cpu().numpy().transpose(1,2,0)[...,[5,15,25]]*255).astype(np.uint8))
        images.append(image.cpu())
        labels.append(label.cpu())
    images = torch.cat(images, 0)
    labels = torch.cat(labels, 0)
    #print(images.shape, labels.shape)
    
    labels = labels.cuda()
    images = images.cuda()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/773008.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

开始尝试从0写一个项目--后端(一)

创建文件的目录结构 利用这个界面创建 序号 名称 说明 1 SEMS maven父工程&#xff0c;统一管理依赖版本&#xff0c;聚合其他子模块 2 sems-common 子模块&#xff0c;存放公共类&#xff0c;例如&#xff1a;工具类、常量类、异常类等 3 sems-pojo 子模块&#x…

硅纪元视角 | AI纳米机器人突破癌症治疗,精准打击肿瘤细胞

在数字化浪潮的推动下&#xff0c;人工智能&#xff08;AI&#xff09;正成为塑造未来的关键力量。硅纪元视角栏目紧跟AI科技的最新发展&#xff0c;捕捉行业动态&#xff1b;提供深入的新闻解读&#xff0c;助您洞悉技术背后的逻辑&#xff1b;汇聚行业专家的见解&#xff0c;…

打卡第2天----数组双指针,滑动窗口

今天是参与训练营第二天&#xff0c;这几道题我都看懂了&#xff0c;自己也能写出来了&#xff0c;实现思路很重要&#xff0c;万事开头难&#xff0c;希望我可以坚持下去。希望最后的结果是量变带来质变。 一、理解双指针思想 leetcode编号&#xff1a;977 不止是在卡尔这里…

深入探讨JavaScript中的队列,结合leetcode全面解读

前言 队列作为一种基本的数据结构&#xff0c;为解决许多实际问题提供了有效的组织和处理方式&#xff0c;对于提高系统的稳定性、可靠性和效率具有重要作用&#xff0c;所以理解队列是很重要的。 本文深入探讨JavaScript中的队列这种数据结构,结合leetcode题目讲解 题目直达…

接口测试工具Postman

Postman Postman介绍 开发API后&#xff0c;用于API测试的工具。在我们平时开发中&#xff0c;特别是需要与接口打交道时&#xff0c;无论是写接口还是用接口&#xff0c;拿到接口后肯定都得提前测试一下。在开发APP接口的过程中&#xff0c;一般接口写完之后&#xff0c;后端…

78110A雷达信号模拟软件

78110A雷达信号模拟软件 78110A雷达信号模拟软件(简称雷达信号模拟软件)主要用于模拟产生雷达发射信号和目标回波信号&#xff0c;软件将编译生成的雷达信号任意波数据下载到信号发生器中&#xff0c;主要是1466-V矢量信号发生器&#xff0c;可实现雷达信号模拟产生。软件可模…

TensorRT-Int8量化详解

int8量化是利用int8乘法替换float32乘法实现性能加速的一种方法 对于常规模型有&#xff1a;y kx b&#xff0c;此时x、k、b都是float32, 对于kx的计算使用float32的乘法 对于int8模型有&#xff1a;y tofp32(toint8(k) * toint8(x)) b&#xff0c;其中int8 * int8结果为in…

SpringBoot的热部署和日志体系

SpringBoot的热部署 每次修改完代码&#xff0c;想看效果的话&#xff0c;不用每次都重新启动代码&#xff0c;等待项目重启 这样就可以了 JDK官方提出的日志框架&#xff1a;Jul log4j的使用方式&#xff1a; &#xff08;1&#xff09;引入maven依赖 &#xff08;2&#x…

头歌资源库(20)最大最小数

一、 问题描述 二、算法思想 使用分治法&#xff0c;可以将数组递归地分割成两部分&#xff0c;直到数组长度为1或2。然后比较这两部分的最大、次大、次小、最小数&#xff0c;最终得到整个数组中的最大两个数和最小两个数。 算法步骤如下&#xff1a; 定义一个函数 findMinM…

uniapp/Android App上架三星市场需要下载所需要的SDK

只需添加以下一个权限在AndroidManifest.xml <uses-permission android:name"com.samsung.android.providers.context.permission.WRITE_USE_APP_FEATURE_SURVEY"/>uniapp开发的&#xff0c;需要在App权限配置中加入以上的额外权限&#xff1a;

Generative Modeling by Estimating Gradients of the Data Distribution

Generative Modeling by Estimating Gradients of the Data Distribution 本文介绍宋飏提出的带噪声扰动的基于得分的生成模型。首先介绍基本的基于得分的生成模型的训练方法&#xff08;得分匹配&#xff09;和采样方法&#xff08;朗之万动力学&#xff09;。然后基于流形假…

2024 年 亚太赛 APMCM (B题)中文赛道国际大学生数学建模挑战赛 |洪水灾害数据分析 | 数学建模完整代码+建模过程全解全析

当大家面临着复杂的数学建模问题时&#xff0c;你是否曾经感到茫然无措&#xff1f;作为2022年美国大学生数学建模比赛的O奖得主&#xff0c;我为大家提供了一套优秀的解题思路&#xff0c;让你轻松应对各种难题&#xff01; 完整内容可以在文章末尾领取&#xff01; 该段文字…

HTML内容爬取:使用Objective-C进行网页数据提取

网页爬取简介 网页爬取&#xff0c;通常被称为网络爬虫或爬虫&#xff0c;是一种自动浏览网页并提取所需数据的技术。这些数据可以是文本、图片、链接或任何网页上的元素。爬虫通常遵循一定的规则&#xff0c;访问网页&#xff0c;解析页面内容&#xff0c;并存储所需信息。 …

自动化立体仓库出入库能力及堆垛机节拍

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》人俱乐部 完整版文件和更多学习资料&#xff0c;请球友到知识星球【智能仓储物流技术研习社】自行下载 自动化立体仓库的出入库能力、堆垛机节拍以…

用720云搭建数字孪生VR智慧安防系统,赋能安防升级!

“安全防范"一直是我国城镇化发展进程中重点关注的工作板块&#xff0c;随着时代发展需求与科技的日新月异&#xff0c;安防行业正在积极融合VR3D数字孪生技术&#xff0c;升级安防数字基础设施和安防产品服务创新。 今年2月&#xff0c;《数字中国建设整体布局规划》的出…

Pycharm的终端(Terminal)中切换到当前项目所在的虚拟环境

1.在Pycharm最下端点击终端/Terminal, 2.点击终端窗口最上端最右边的∨&#xff0c; 3.点击Command Prompt&#xff0c;切换环境&#xff0c; 可以看到现在环境已经由默认的PS(Window PowerShell)切换为项目所使用的虚拟环境。 4.更近一步&#xff0c;如果想让Pycharm默认显示…

macOS使用Karabiner-Elements解决罗技鼠标G304连击、单击变双击的故障

记录一下罗技鼠标G304单击变双击的软件解决过程和方案&#xff08;适用于macOS&#xff0c; 如果是Windows&#xff0c;使用AutoHotKey也有类似解决办法、方案&#xff0c;改日提供&#xff09;&#xff1a; 背景&#xff1a;通过罗技Logitech G HUB软件对罗技的游戏鼠标侧键b…

1-4 NLP发展历史与我的工作感悟

1-4 NLP发展历史与我的工作感悟 主目录点这里 第一个重要节点&#xff1a;word2vec词嵌入 能够将无限的词句表示为有限的词向量空间&#xff0c;而且运算比较快&#xff0c;使得文本与文本间的运算有了可能。 第二个重要节点&#xff1a;Transformer和bert 为预训练语言模型发…

2024 世界人工智能大会暨人工智能全球治理高级别会议全体会议在上海举办,推动智能向善造福全人类

2024 年 7 月 4 日&#xff0c;2024 世界人工智能大会暨人工智能全球治理高级别会议-全体会议在上海世博中心举办。联合国以及各国政府代表、专业国际组织代表&#xff0c;全球知名专家、企业家、投资家 1000 余人参加了本次会议&#xff0c;围绕“以共商促共享&#xff0c;以善…

搜维尔科技:如何使用 SenseGlove Nova 加速手部运动功能的恢复

District XR 的VR 培训 5 年多来&#xff0c;District XR 一直在为最大的工业公司创建 VR 和 AR 项目。 客户&#xff1a;District XR 客户代表&#xff1a;尼古拉沃尔科夫 他的角色&#xff1a;District XR 首席执行官 面临解决的挑战 该公司正在寻找一种方法来加速身体伤…