记录详细的pytorch文本情感分类实战过程
word embedding APItorch.nn.Embedding(num_embeddings, embedding_dim)参数介绍:1.num_embeddings:词典的大小 (当前词典中不重复词的个数)2.embedding_dim:embedding的维度(用多长的向量表示我们的词语)使用方法:embedding = nn.Embedding(vocab_size, 300) #
word embedding API
torch.nn.Embedding(num_embeddings, embedding_dim)
参数介绍:1.num_embeddings:词典的大小 (当前词典中不重复词的个数)
2.embedding_dim:embedding的维度(用多长的向量表示我们的词语)
使用方法:embedding = nn.Embedding(vocab_size, 300) # 实例化 embedding维度为300维
input_embeded = embedding(input) # 进行embedding操作
数据的形状变化
如果每个batch中的每个句子有十个词语,经过形状为[20, 4]的word embedding 之后,原来的句子会变成什么样的形状?
//这里的20是指整个词库中不重复的词的个数,4是指我们想要embedding的维度 实际就是一个降维的过程 我们根据实际情况选择embedding_dim
每个词语用长度为4的向量表示,所以,最终会变成[batch_size , 10 , 4]的形状
//这里的10就是seq_len 句子长度
所以形状的变化为:[batch_size, seq_len] -----[batch_size, seq_len, embedding_dim]
这里插入一张图 方便理解word embedding的过程
思路分析:
首先可以把上述问题定义成一个分类问题,情感评分为1-10,10个类别,那么我们大致的流程如下:
1.准备数据集
2.构建模型
3.模型训练
4.模型评估
准备数据集
需要注意的点:1.如何完成基础的dataset和dataloader的准备
2.如何解决每个batch中的文本长度不一致的问题
3.如何解决每个batch中的文本转化为序列的问题
基础Dataset的准备
简单的数据预处理 这里涉及到正则表达式 re.sub re.S 和flags的操作:
从上面的代码中可以看到re.sub()方法中含有5个参数
(1)pattern:该参数表示正则中的模式字符串;
(2)repl:该参数表示要替换的字符串(即匹配到pattern后替换为repl),也可以是个函数;
(3)string:该参数表示要被处理(查找替换)的原始字符串;
(4)count:可选参数,表示是要替换的最大次数,而且必须是非负整数,该参数默认为0,即所有的匹配都会被替换;
(5)flags:可选参数,表示编译时用的匹配模式(如忽略大小写、多行模式等),数字形式,默认为0。 这里的flags = re.S
使 “.” 特殊字符完全匹配任何字符,包括换行;没有这个标志, “.” 匹配除了换行符外的任何字符。
strip()的用法:
str.strip() : 去除字符串两边的空格
str.lstrip() : 去除字符串左边的空格
str.rstrip() : 去除字符串右边的空格
注:此处的空格包含’\n’, ‘\r’, ‘\t’, ’ ’
def tokenize(text):
if __name__ == '__main__':
filters = ['!', '"', '#', '$', '%', '&', '\(',
'\)', '\*', '\+', ',', '-', '\.', '/', ':', ';',
'<', '=', '>', '\?', '@',
'\[', '\\', '\]', '^', '_', '~', '\{',
'\|', '\}', '`', '\t', '\x97', '\x96', '“', '”']
text = re.sub("<.*?>", " ", text, flags=re.S)
text = re.sub("|".join(filters), " ", text, flags=re.s)
return [i.strip() for i in text.split()]
整个数据准备阶段代码
import torch
from torch.utils.data import Dataset, DataLoader
import os
import re
data_base_path ="./data"
# 1.定义tokenize的方法
def tokenize(content):
filters = ['!', '"', '#', '$', '%', '&', '\(',
'\)', '\*', '\+', ',', '-', '\.', '/', ':', ';',
'<', '=', '>', '\?', '@',
'\[', '\\', '\]', '^', '_', '~', '\{',
'\|', '\}', '`', '\t', '\x97', '\x96', '“', '”']
content = re.sub("<.*?>", " ", content, flags=re.S) # 替换成空字符
content = re.sub("|".join(filters), " ", content, flags=re.S)
tokens = [i.strip().lower() for i in content.split()] # 分词操作
return tokens
# 2.准备dataset
class ImdbDataset(Dataset):
def __init__(self, train=True): # 1.文件的路径确定
self.train_data_path = "./data/train"
self.test_data_path = "./data/test"
data_path = self.train_data_path if train else self.test_data_path
# 把所有的文件名放入列表
temp_data_path = [os.path.join(data_path, "pos"), os.path.join(data_path,"neg")]
self.total_file_path = [] # 所有的评论的文件的路径
for path in temp_data_path:
file_name_list = os.listdir(path) # listdir返回路径中文件夹的文件名
file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith(".txt")] # 过滤 只要.txt结尾的
self.total_file_path.extend(file_path_list) # append是追加一个list extend是把两个list拼接起来
def __getitem__(self, index):
file_path = self.total_file_path[index] # 读取当前对应位置的filepath 拿到filepath之后 我们要获取他们的内容和labels
# 获取label
label_str = file_path.split('/')[-2]
label = 0 if label_str == "neg" else 1
# 获取内容
content = open(file_path).read()
tokens = tokenize(content)
return tokens, label
def __len__(self):
return len(self.total_file_path)
# 准备dataloader 定义一个方法 选择train和test模式
def get_dataloader(train=True):
imdb_dataset = ImdbDataset(train)
data_loader = DataLoader(imdb_dataset, batch_size=1, shuffle=True)
return data_loader
if __name__ == "__main__": # 这里的 if name 一定要和class对齐!!!
# imdb_dataset = ImdbDataset()
# print(imdb_dataset[0])
# my_str = 'It seems more than passing strange that such utter dreck as "Dukes of Hazzard" and "The Hills Have Eyes" (the new version) can find'
# print(tokenize(my_str))
for idx, (input, target) in enumerate(get_dataloader()):
print(idx)
print(input)
print(target)
break
这里插入collate_fn实现方法
collate_fn的默认值为torch自定义的default_collate,collate_fn的作用就是对每个batch进行处理,而默认的default_collate处理出错
解决方法:1.考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误
2.考虑自定义一个collate_fn,然后观察结果:
def collate_fn(batch):
# batch 是list,其中是一个一个元组,每个元组是dataset中__getitem__的结果
batch = list(zip(*batch))
labels = torch.tensor(batch[0], dtype=torch.int32)
texts = batch[1]
del batch
return labels, texts
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True,collate_fn=collate_fn)
# 此时输出正常
更多推荐
所有评论(0)