#列出所有环境
#conda info --env
#conda activate mypytorch
#安装opencv环境
#pip3 install opencv-python


```python

```python
import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
#正则表达式匹配出最后的数字:12
#print(re.findall("(\d+)","flower")[-1])
#创建自定义DataSet类
class myDataSet(Dataset):
    #每个分类的子文件夹独立成一个标签数据集,标签例如flower0
    def __init__(self,rootdir,labeldir):
        self.rootdir=rootdir
        self.labeldir=labeldir
        self.imagePaths=os.path.join(rootdir,labeldir)
    '''
    #item作为编号:opencv版本
    def __getitem__(self, item):
        imagePath=os.listdir(self.imagePaths)[item]
        imagePath=os.path.join(self.imagePaths,imagePath)
        img=cv2.imdecode(np.fromfile(imagePath,np.uint8),-1)
        #bgr转rgb
        img = img[:, :, ::-1]
        labelComopent =re.findall("(\d+)",self.labeldir)
        #如果在标签中取不出对应tag
        if len(labelComopent)==0:
            raise ValueError
        label=int(labelComopent[-1])
        return img,label
    '''
    #item作为编号:opencv版本
    def __getitem__(self, item):
        imagePath=os.listdir(self.imagePaths)[item]
        imagePath=os.path.join(self.imagePaths,imagePath)
        img=Image.open(imagePath)
        img = np.array(img)
        labelComopent =re.findall("(\d+)",self.labeldir)
        #如果在标签中取不出对应tag
        if len(labelComopent)==0:
            raise ValueError
        label=int(labelComopent[-1])
        return img,label
    def __len__(self):
        return len(self.imagePaths)
#使用r标识路径防止转义:
rootdir=r"D:\17flowers"
labelList=os.listdir(rootdir)
allDataSet=[]
#生成各子数据集
for label in labelList:
    allDataSet.append(myDataSet(rootdir,label))
'''
reduce() 函数会对参数序列中元素进行累积。
函数将一个数据集合(链表,元组等)中的所有数据进行下列操作:
用传给 reduce 中的函数 function(有两个参数)先对集合中的第 1、2 个元素进行操作,得到的结果再与第三个数据用 function 函数运算,最后得到一个结果。
'''
trainDataSet=reduce(lambda x,y:x+y,allDataSet)
#347
print(trainDataSet.__len__())
#(500, 689, 3)
print(trainDataSet[0][0].shape)
print(trainDataSet[0][1])

下面做数据集标签格式转换的工作:

import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
#正则表达式匹配出最后的数字:12
#print(re.findall("(\d+)","flower")[-1])
#创建自定义DataSet类
class myDataSet(Dataset):
    #每个分类的子文件夹独立成一个标签数据集,标签例如flower0
    def __init__(self,rootdir,labeldir):
        self.rootdir=rootdir
        self.labeldir=labeldir
        self.imagePaths=os.path.join(rootdir,labeldir)
    '''
    #item作为编号:opencv版本
    def __getitem__(self, item):
        imagePath=os.listdir(self.imagePaths)[item]
        imagePath=os.path.join(self.imagePaths,imagePath)
        img=cv2.imdecode(np.fromfile(imagePath,np.uint8),-1)
        #bgr转rgb
        img = img[:, :, ::-1]
        labelComopent =re.findall("(\d+)",self.labeldir)
        #如果在标签中取不出对应tag
        if len(labelComopent)==0:
            raise ValueError
        label=int(labelComopent[-1])
        return img,label
    '''
    #item作为编号:opencv版本
    def __getitem__(self, item):
        imagePath1=os.listdir(self.imagePaths)[item]
        imagePath=os.path.join(self.imagePaths,imagePath1)
        img=Image.open(imagePath)
        img = np.array(img)
        labelComopent =re.findall("(\d+)",self.labeldir)
        #如果在标签中取不出对应tag
        if len(labelComopent)==0:
            raise ValueError
        label=int(labelComopent[-1])
        return img,(label,imagePath1)
    def __len__(self):
        return len(self.imagePaths)
#使用r标识路径防止转义:
rootdir=r"D:\17flowers"
labelList=os.listdir(rootdir)
allDataSet=[]
#生成各子数据集
for label in labelList:
    allDataSet.append(myDataSet(rootdir,label))
'''
reduce() 函数会对参数序列中元素进行累积。
函数将一个数据集合(链表,元组等)中的所有数据进行下列操作:
用传给 reduce 中的函数 function(有两个参数)先对集合中的第 1、2 个元素进行操作,得到的结果再与第三个数据用 function 函数运算,最后得到一个结果。
'''
trainDataSet=reduce(lambda x,y:x+y,allDataSet)
#347
print(trainDataSet.__len__())
#(500, 689, 3)
# print(trainDataSet[0][0].shape)
#image_0001.jpg
#print(trainDataSet[0][1][1])
#label写入text文件:同时连接3部分:
rootdir=r"D:\17flowers2"
for data,label in trainDataSet:
    fileName=(label[1].split(".jpg")[0])
    with open(os.path.join(rootdir,"trainLabel","{}.txt".format(fileName)),'w') as f:
        f.write(str(label[0]))





Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐