Seq2Seq识别车牌项目demo
一、Seq2Seq相关概念1、Seq2Seq的作用:Seq2Seq 是一类特殊的 RNN,在机器翻译、文本自动摘要和语音识别中有着成功的应用,可以应用它来解决输入和输出不等长问题,典型的是在语言的互译过程中,输入和输出不等长。2、Seq2Seq的结构分为两部分:编码器和解码器其中:编码器和解码器分别是两个RNN的网络,编码器用来分析输入序列,解码器用来生成输出序列。此处可以使用的 RNN 变体:RNN 可以是单向的或双向的……
·
一、Seq2Seq相关概念
1、Seq2Seq的作用:
Seq2Seq 是一类特殊的 RNN,在机器翻译、文本自动摘要和语音识别中有着成功的应用,可以应用它来解决输入和输出不等长问题,典型的是在语言的互译过程中,输入和输出不等长。
2、Seq2Seq的结构
分为两部分:编码器和解码器
其中:编码器和解码器分别是两个RNN的网络,编码器用来分析输入序列,解码器用来生成输出序列。
此处可以使用的 RNN 变体:
- RNN 可以是单向的或双向的,后者将捕捉双向的长时间依赖关系。
- RNN 可以有多个隐藏层,层数的选择对于优化来说至关重要...更深的网络可以学到更多知识,另一方面,训练需要花费很长时间而且可能会过度拟合。
- RNN 可以有多个隐藏层,层数的选择对于优化来说至关重要...更深的网络可以学到更多知识,另一方面,训练需要花费很长时间而且可能会过度拟合。
- RNN 可以具有嵌入层,其将单词映射到嵌入空间中,在嵌入空间中相似单词的映射恰好也非常接近。
- RNN 可以使用简单的重复性单元、LSTM、窥孔 LSTM 或者 GRU。
二、识别车牌
1、制作车牌数据集
注意:
- 模拟的车牌的图片名为7个数字对应车牌中的字符,用作网络学习的标签
- 为了增加模拟车牌的多样性,可以在生成的模拟车牌中添加噪声、仿射变换....等操作
import numpy as np
import cv2
from PIL import ImageFont, ImageDraw, Image
import os
# Ordered plate alphabet: 31 province abbreviations, then the digits 0-9,
# then 24 Latin letters (the table contains no "I" and no "O").
# A character's position in this list is its class index.
chars_list = (
    list("京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新")
    + list("0123456789")
    + list("ABCDEFGHJKLMNPQRSTUVWXYZ")
)
# Reverse lookup: character -> class index (0..64), in the same order as chars_list.
char_index_map = {char: index for index, char in enumerate(chars_list)}
# Font used by getPlate for the first plate character (size 66), loaded from ./simhei.ttf.
font_Image = ImageFont.truetype("./simhei.ttf", 66, encoding="utf-8")
# Font used by getPlate for the remaining six characters (size 72), same font file.
font_Image2=ImageFont.truetype("./simhei.ttf", 72, encoding="utf-8")
# Province abbreviations that may appear in the first plate position.
abbr_chars = list("京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新")
# Latin letters that may appear on a plate (no "I" and no "O" in this table).
letter_chars = list("ABCDEFGHJKLMNPQRSTUVWXYZ")
# Decimal digits.
num_chars = list("0123456789")
# Top-left drawing position (x, y) of each of the seven characters.
zoords = [(x, 30) for x in (15, 78, 162, 215, 265, 315, 365)]
# Notes on the options accepted by getPlate:
#   isPosRandom      - jitter each character's position so the samples are
#                      less regular, which helps on real images after training.
#   isGauss          - apply a Gaussian blur; real-world use may additionally
#                      need affine transforms.
#   first_char_index - the training set must contain every province
#                      character, so the plate starts with the character
#                      selected by this index.

# Supported plate colours; the generator script iterates over this list.
plate_type = ["green", "blue", "yellow"]
# Text colour for each plate colour, as an RGB tuple handed to PIL:
# white text on blue plates, black text on green and yellow plates.
plate_colors = {"blue": (255, 255, 255), "green": (0, 0, 0), "yellow": (0, 0, 0)}


def _solid_bgr(blue, green, red, height=150, width=420):
    """Return a (height, width, 3) uint8 image filled with one BGR colour."""
    canvas = np.empty((height, width, 3), dtype=np.uint8)
    canvas[:] = (blue, green, red)
    return canvas


# Single-channel 150x420 template; kept because it is a module-level name.
img = np.ones((150, 420), dtype=np.uint8)

# Solid backgrounds in OpenCV channel order (B, G, R).
y_bg = _solid_bgr(0, 255, 255)  # yellow plate background
g_bg = _solid_bgr(0, 255, 0)    # green plate background
b_bg = _solid_bgr(255, 0, 0)    # blue plate background
bg_dic = {"yellow": y_bg, "blue": b_bg, "green": g_bg}
def getPlate(bg_path_root="", plate_type="blue", isPosRandom=False, isGauss=False, first_char_index=-1):
    """Render one synthetic 7-character licence plate.

    Args:
        bg_path_root: Retained for backward compatibility and no longer used.
            The background used to be written to ``<bg_path_root>/<type>.png``
            and read straight back, which failed when the directory did not
            exist and left background PNGs inside the dataset folder.
        plate_type: Key into ``plate_colors`` / ``bg_dic``
            ("blue", "green" or "yellow").
        isPosRandom: If truthy, shift every character by a random offset in
            [-5, 4] pixels on both axes.
        isGauss: If truthy, apply a 9x9 Gaussian blur to the finished plate.
        first_char_index: If >= 0, the first character is
            ``abbr_chars[first_char_index % len(abbr_chars)]``; otherwise it
            is chosen at random.

    Returns:
        ``(image, indices)`` where ``image`` is a BGR numpy array and
        ``indices`` is the list of seven class indices from
        ``char_index_map``; ``(None, None)`` if rendering raised an error.

    Raises:
        KeyError: If ``plate_type`` is not a known plate colour.
    """
    text_color = plate_colors[plate_type]  # colour of the characters, not of the background
    # Use the in-memory background directly; no round trip through a PNG on disk.
    background_bgr = np.uint8(bg_dic[plate_type])
    chars_index_arr = []
    try:
        # PIL draws the text, so convert the BGR background to an RGB PIL image.
        image_pil = Image.fromarray(cv2.cvtColor(background_bgr, cv2.COLOR_BGR2RGB))
        imDraw = ImageDraw.Draw(image_pil)
        for i in range(7):  # seven characters per plate
            if first_char_index >= 0 and i == 0:
                char = abbr_chars[first_char_index % len(abbr_chars)]
            else:
                char = getCharByIndex(i)
            # Offsets are drawn on every iteration, whether or not they are used.
            offset_x = np.random.randint(-5, 5)
            offset_y = np.random.randint(-5, 5)
            zoord = zoords[i]
            if isPosRandom:
                zoord = (zoord[0] + offset_x, zoord[1] + offset_y)
            chars_index_arr.append(char_index_map[char])
            if i == 0:
                # The first character uses the smaller font.
                imDraw.text(zoord, text=char, font=font_Image, fill=text_color,
                            stroke_width=1, stroke_fill=text_color)
            else:
                imDraw.text(zoord, text=char, font=font_Image2, fill=text_color,
                            stroke_width=1, stroke_fill=text_color, align="center")
        img_cv2 = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)
        if isGauss:
            img_cv2 = cv2.GaussianBlur(img_cv2, (9, 9), 0)
        return img_cv2, chars_index_arr
    except Exception as e:
        # Best effort: report the problem and return the (None, None) sentinel.
        print(e)
    return None, None
def getCharByIndex(index):
    """Return a random character that is valid at plate position ``index``.

    Position 0 yields a province abbreviation, position 1 a letter, and any
    other position a letter or a digit.
    """
    if index == 0:
        pool = abbr_chars
    elif index == 1:
        pool = letter_chars
    else:
        pool = letter_chars + num_chars
    return pool[np.random.randint(0, len(pool))]
def getCharsByIndexs(index_list):
    """Return the plate text for a sequence of class indices into ``chars_list``."""
    return "".join(chars_list[v] for v in index_list)
if __name__ == '__main__':
    # Plate generator: writes 1000 plates per colour into ./plate.
    # Each file is named after the seven class indices of its characters,
    # comma separated, so the file name doubles as the training label.
    #
    # Single-plate usage, for reference:
    #   plate, indices = getPlate("", "yellow", True, True, 13)
    #   print(indices, getCharsByIndexs(indices))
    image_path = r"./plate"
    # Create the output directory before anything is generated or written.
    os.makedirs(image_path, exist_ok=True)
    for _plate_type in plate_type:  # one pass per plate colour
        for i in range(1000):  # number of plates generated per colour
            # getPlate wraps this index, so every province character is covered.
            first_char_index = i
            isPosRandom = np.random.randint(0, 2)  # randomly jitter character positions
            isGauss = np.random.randint(0, 2)  # randomly blur the plate
            # isGauss is now passed as the blur flag; isPosRandom used to be passed twice.
            image, index_list = getPlate("", _plate_type, isPosRandom, isGauss,
                                         first_char_index)
            if image is None:
                # getPlate already printed the error; skip this sample.
                continue
            index_chars = ",".join(str(v) for v in index_list)
            cv2.imwrite("{0}/{1}.jpg".format(image_path, index_chars), image)
样本示例如下:
2.搭建网络训练
(1)采样文件
注意:此处使用二值化将图片转为黑白形式,学习更容易
import os
import torch
import numpy as np
from PIL import Image
import torch.utils.data as data
from torchvision import transforms
import cv2
from sklearn.preprocessing import OneHotEncoder
# Preprocessing applied to every sample: convert the PIL image to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
data_transforms = transforms.Compose([_to_tensor, _normalize])
class Sampling(data.Dataset):
    """Dataset of plate images whose file names encode their labels.

    A sample file is named ``i0,i1,...,i6.<ext>`` where each ``i`` is a class
    index in ``[0, NUM_CLASSES)``. Directory entries that do not follow this
    pattern (for example stray background images) are skipped rather than
    crashing later in ``one_hot``.
    """

    SEQ_LEN = 7        # characters per plate
    NUM_CLASSES = 65   # size of the plate alphabet

    def __init__(self, root):
        """Index every correctly named image directly under ``root``."""
        self.transform = data_transforms
        self.imgs = []    # image paths
        self.labels = []  # per image: list of SEQ_LEN index strings
        # Sorted so the sample order does not depend on the file system.
        for filename in sorted(os.listdir(root)):
            label = self._parse_label(filename)
            if label is None:
                continue
            self.imgs.append(os.path.join(root, filename))
            self.labels.append(label)

    @classmethod
    def _parse_label(cls, filename):
        """Return the index strings encoded in ``filename``, or None if invalid."""
        fields = filename.split('.')[0].split(",")
        if len(fields) != cls.SEQ_LEN:
            return None
        try:
            indices = [int(field) for field in fields]
        except ValueError:
            return None
        if any(not 0 <= value < cls.NUM_CLASSES for value in indices):
            return None
        return fields

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label)`` for sample ``index``.

        Raises:
            OSError: If the image file cannot be read.
        """
        img_path = self.imgs[index]
        img = cv2.imread(img_path, 1)
        if img is None:
            raise OSError("cannot read image: {}".format(img_path))
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Otsu threshold: binarise to black and white before training.
        ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        # Back to three channels because the transform normalises three channels.
        img = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
        img = Image.fromarray(img)
        img = self.transform(img)
        label = self.one_hot(self.labels[index])
        return img, label

    def one_hot(self, x):
        """Encode the first SEQ_LEN indices of ``x`` as a (SEQ_LEN, NUM_CLASSES) float array."""
        z = np.zeros(shape=[self.SEQ_LEN, self.NUM_CLASSES])
        for i in range(self.SEQ_LEN):
            z[i][int(x[i])] = 1
        return z
if __name__ == '__main__':
    # Smoke test: load every sample one at a time and print the tensor shapes.
    plate_dataset = Sampling("./plate")
    loader = data.DataLoader(plate_dataset, 1, shuffle=True)
    for img, label in loader:
        # With batch size 1 and the 150x420 generated plates the shapes are
        # expected to be [1, 3, 150, 420] and [1, 7, 65].
        print(img.shape)
        print(label.shape)
(2)训练网络
注意:主网络分为解码器和编码器,此处使用了LSTM
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import myseq_sampling
class Encoder(nn.Module):
    """Encode an image, read column by column, into one feature vector.

    Every image column (all channels stacked, ``in_features`` values) is
    projected to ``hidden_size`` features; the ``seq_len`` columns are then
    fed to an LSTM as a sequence and the output at the last step is returned.

    Args:
        in_features: Values per image column (channels * height); 450 = 3 * 150.
        seq_len: Number of columns, i.e. the image width; 420 by default.
        hidden_size: Width of the projection and of the LSTM state.

    The defaults reproduce the original fixed-size network, and the parameter
    names (``fc1``, ``lstm``) are unchanged, so existing checkpoints load.
    """

    def __init__(self, in_features=450, seq_len=420, hidden_size=512):
        super().__init__()
        self.in_features = in_features
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.fc1 = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.BatchNorm1d(num_features=hidden_size),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True)

    def forward(self, x):
        """Map a batch of N images to an ``(N, hidden_size)`` tensor."""
        # (N, C, H, W) -> (N, C*H, W) -> (N, W, C*H): one row per image column.
        x = x.reshape(-1, self.in_features, self.seq_len).permute(0, 2, 1)
        # Flatten batch and columns so Linear/BatchNorm1d see 2-D input: (N*W, C*H).
        x = x.reshape(-1, self.in_features)
        fc1 = self.fc1(x)
        fc1 = fc1.reshape(-1, self.seq_len, self.hidden_size)  # (N, W, hidden)
        lstm, _ = self.lstm(fc1, None)
        # Keep only the output at the last time step.
        return lstm[:, -1, :]
class Decoder(nn.Module):
    """Decode one feature vector into per-character class scores.

    The feature vector is repeated ``seq_len`` times, run through a
    two-layer LSTM, and every step is mapped to ``num_classes`` scores.

    Args:
        hidden_size: Size of the input vector and of the LSTM state.
        seq_len: Number of characters to emit (7 per plate).
        num_classes: Size of the character alphabet (65).

    The defaults reproduce the original fixed-size network, and the parameter
    names (``lstm``, ``out``) are unchanged, so existing checkpoints load.
    """

    def __init__(self, hidden_size=512, seq_len=7, num_classes=65):
        super().__init__()
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.lstm = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=2,
                            batch_first=True)
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map ``(N, hidden_size)`` features to ``(N, seq_len, num_classes)`` scores."""
        x = x.reshape(-1, 1, self.hidden_size)
        # Present the same vector at every output step.
        x = x.expand(-1, self.seq_len, self.hidden_size)
        lstm, _ = self.lstm(x, None)
        y1 = lstm.reshape(-1, self.hidden_size)  # (N*seq_len, hidden)
        out = self.out(y1)
        return out.reshape(-1, self.seq_len, self.num_classes)
class MainNet(nn.Module):
    """Full recogniser: an Encoder followed by a Decoder."""

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode the image batch, then decode the features into character scores."""
        features = self.encoder(x)
        return self.decoder(features)
if __name__ == '__main__':
    # Training script: fits MainNet on the generated plates with an MSE loss
    # against the one-hot labels and periodically reports accuracy.
    BATCH = 64
    EPOCH = 100000
    save_path = './my_param/seq2seq.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = MainNet().to(device)
    opt = torch.optim.Adam(net.parameters())
    loss_func = nn.MSELoss()
    if os.path.exists(save_path):
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        net.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print("No Params!")
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    train_data = myseq_sampling.Sampling(root="./plate")
    train_loader = data.DataLoader(dataset=train_data,
                                   batch_size=BATCH, shuffle=True, num_workers=4)
    for epoch in range(EPOCH):
        for i, (x, y) in enumerate(train_loader):
            batch_x = x.to(device)
            batch_y = y.float().to(device)
            output = net(batch_x)  # (N, 7, 65), same shape as batch_y
            loss = loss_func(output, batch_y)
            if i % 10 == 0:
                print(loss)
            opt.zero_grad()
            loss.backward()
            opt.step()
            if i % 100 == 0:
                label_y = torch.argmax(y, 2).detach().numpy()
                out_y = torch.argmax(output, 2).cpu().detach().numpy()
                # Divide by the actual number of characters in this batch;
                # the last batch of an epoch may hold fewer than BATCH samples.
                accuracy = np.sum(out_y == label_y, dtype=np.float32) / label_y.size
                print("epoch:{},i:{},loss:{:.6f},acc:{:.2f}%".format(epoch, i, loss.item(), accuracy * 100))
                print("label_y:", label_y[0])
                print("out_y:", out_y[0])
                # NOTE(review): the original indentation was lost; the checkpoint is
                # assumed to be saved together with the periodic report - confirm.
                torch.save(net.state_dict(), save_path)
样本比较少时可以训练出过拟合版本:
参考资料:
http://c.biancheng.net/view/1947.html
http://zh.gluon.ai/chapter_natural-language-processing/seq2seq.html
https://blog.csdn.net/Jerr__y/article/details/53749693
更多推荐
已为社区贡献2条内容
所有评论(0)