word embedding API

 

torch.nn.Embedding(num_embeddings, embedding_dim)

参数介绍:1.num_embeddings:词典的大小 (当前词典中不重复词的个数)

                  2.embedding_dim:embedding的维度(用多长的向量表示我们的词语)

使用方法:embedding = nn.Embedding(vocab_size, 300) # 实例化 embedding维度为300维

                  input_embeded = embedding(input)        # 进行embedding操作

数据的形状变化

如果每个batch中的每个句子有十个词语,经过形状为[20, 4]的word embedding 之后,原来的句子会变成什么样的形状?

//这里的20是指整个词库中不重复的词的个数,4是指我们想要embedding的维度 实际就是一个降维的过程 我们根据实际情况选择embedding_dim        

每个词语用长度为4的向量表示,所以,最终会变成[batch_size , 10 , 4]的形状 

//这里的10就是seq_len  句子长度

所以形状的变化为:[batch_size, seq_len] -----[batch_size, seq_len, embedding_dim] 

这里插入一张图 方便理解word embedding的过程



思路分析:

首先可以把上述问题定义成一个分类问题,情感评分为1-10,10个类别,那么我们大致的流程如下:

1.准备数据集

2.构建模型

3.模型训练

4.模型评估

准备数据集 

需要注意的点:1.如何完成基础的dataset和dataloader的准备

                         2.如何解决每个batch中的文本长度不一致的问题

                         3.如何解决每个batch中的文本转化为序列的问题

基础Dataset的准备

 简单的数据预处理 这里涉及到正则表达式 re.sub re.S 和flags的操作:

 从上面的代码中可以看到re.sub()方法中含有5个参数
(1)pattern:该参数表示正则中的模式字符串;
(2)repl:该参数表示要替换的字符串(即匹配到pattern后替换为repl),也可以是个函数;
(3)string:该参数表示要被处理(查找替换)的原始字符串;
(4)count:可选参数,表示是要替换的最大次数,而且必须是非负整数,该参数默认为0,即所有的匹配都会被替换;
(5)flags:可选参数,表示编译时用的匹配模式(如忽略大小写、多行模式等),数字形式,默认为0。 这里的flags = re.S
    使 “.” 特殊字符完全匹配任何字符,包括换行;没有这个标志, “.” 匹配除了换行符外的任何字符。


strip()的用法:

str.strip() : 去除字符串两边的空格
str.lstrip() : 去除字符串左边的空格
str.rstrip() : 去除字符串右边的空格

注:此处的空格包含’\n’, ‘\r’, ‘\t’, ’ ’


def tokenize(text):
    if __name__ == '__main__':
        filters = ['!', '"', '#', '$', '%', '&', '\(',
                   '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';',
                   '<', '=', '>', '\?', '@',
                   '\[', '\\', '\]', '^', '_', '~', '\{',
                   '\|', '\}', '`', '\t', '\x97', '\x96', '“', '”']
        text = re.sub("<.*?>", " ", text, flags=re.S)
        text = re.sub("|".join(filters), " ", text, flags=re.s)
        return [i.strip() for i in text.split()]


整个数据准备阶段代码

import torch
from torch.utils.data import Dataset, DataLoader
import os
import re

data_base_path ="./data"
# 1.定义tokenize的方法


def tokenize(content):
    filters = ['!', '"', '#', '$', '%', '&', '\(',
                   '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';',
                   '<', '=', '>', '\?', '@',
                   '\[', '\\', '\]', '^', '_', '~', '\{',
                   '\|', '\}', '`', '\t', '\x97', '\x96', '“', '”']
    content = re.sub("<.*?>", " ", content, flags=re.S)  # 替换成空字符
    content = re.sub("|".join(filters), " ", content, flags=re.S)
    tokens = [i.strip().lower() for i in content.split()] # 分词操作
    return tokens

# 2.准备dataset


class ImdbDataset(Dataset):
    def __init__(self, train=True):   # 1.文件的路径确定
        self.train_data_path = "./data/train"
        self.test_data_path = "./data/test"
        data_path = self.train_data_path if train else self.test_data_path
        # 把所有的文件名放入列表
        temp_data_path = [os.path.join(data_path, "pos"), os.path.join(data_path,"neg")]
        self.total_file_path = []  # 所有的评论的文件的路径
        for path in temp_data_path:
            file_name_list = os.listdir(path)  # listdir返回路径中文件夹的文件名
            file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith(".txt")]  # 过滤 只要.txt结尾的
            self.total_file_path.extend(file_path_list) # append是追加一个list extend是把两个list拼接起来

    def __getitem__(self, index):
        file_path = self.total_file_path[index]  # 读取当前对应位置的filepath 拿到filepath之后 我们要获取他们的内容和labels
        # 获取label
        label_str = file_path.split('/')[-2]
        label = 0 if label_str == "neg" else 1
        # 获取内容
        content = open(file_path).read()
        tokens = tokenize(content)
        return tokens, label

    def __len__(self):
        return len(self.total_file_path)

# 准备dataloader 定义一个方法 选择train和test模式


def get_dataloader(train=True):
    imdb_dataset = ImdbDataset(train)
    data_loader = DataLoader(imdb_dataset, batch_size=1, shuffle=True)
    return data_loader


if __name__ == "__main__":  # 这里的 if name 一定要和class对齐!!!
    # imdb_dataset = ImdbDataset()
    # print(imdb_dataset[0])
    # my_str = 'It seems more than passing strange that such utter dreck as "Dukes of Hazzard" and "The Hills Have Eyes" (the new version) can find'
    # print(tokenize(my_str))
    for idx, (input, target) in enumerate(get_dataloader()):
        print(idx)
        print(input)
        print(target)
        break

这里插入collate_fn实现方法

collate_fn的默认值为torch自定义的default_collate,collate_fn的作用就是对每个batch进行处理,而默认的default_collate处理出错

解决方法:1.考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误

                  2.考虑自定义一个collate_fn,然后观察结果:

def collate_fn(batch):
    # batch 是list,其中是一个一个元组,每个元组是dataset中__getitem__的结果
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True,collate_fn=collate_fn)

# 此时输出正常

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐