gitbook/PyTorch深度学习实战/docs/464152.md
2022-09-03 22:05:03 +08:00

300 lines
20 KiB
Markdown
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 24 | 文本分类如何使用BERT构建文本分类模型
你好,我是方远。
在第22节课我们一起学习了不少文本处理方面的理论其实文本分类在机器学习领域的应用也非常广泛。
比如说你现在是一个NLP研发工程师老板啪地一下甩给你一大堆新闻文本数据它们可能来源于不同的领域比如体育、政治、经济、社会等类型。这时我们就需要对文本分类处理方便用户快速查询自己感兴趣的内容甚至按用户的需要定向推荐某类内容。
这样的需求就非常适合用PyTorch + BERT处理。为什么会选择BERT呢因为BERT是比较典型的深度学习NLP算法模型也是业界使用最广泛的模型之一。接下来我们就一起来搭建这个文本分类模型相信我它的效果表现非常强悍。
## 问题背景与分析
正式动手之前,我们不妨回顾一下历史。文本分类问题有很多经典解决办法。
开始时就是最简单粗暴的关键词统计方法。之后又有了基于贝叶斯概率的分类方法,通过某些条件发生的概率推断某个类别的概率大小,并作为最终分类的决策依据。尽管这个思想很简单,但是意义重大,时至今日,贝叶斯方法仍旧是非常多应用场景下的好选择。
之后还有支持向量机SVM很长一段时间其变体和应用都在NLP算法应用的问题场景下占据统治地位。
随着计算设备性能的提升、新的算法理论的产生等进步一大批的诸如随机森林、LDA主题模型、神经网络等方法纷纷涌现可谓百家争鸣。
既然有这么多方法为什么这里我们这里推荐选用BERT呢
因为在很多情况下尤其是一些复杂场景下的文本像BERT这样具有强大处理能力的工具才能应对。比如说新闻文本就不好分类因为它存在后面这些问题。
1.**类别多**。在新闻资讯App中新闻的种类是非常多的需要产品经理按照统计、实用的原则进行文章分类体系的设计使其类别能够覆盖所有的文本一般来说都有50种甚至以上。不过为了让你把握重点咱们先简化问题假定文本的分类体系已经确定。
2.**数据不平衡**。不难理解,在新闻中,社会、经济、体育、娱乐等类别的文章数量相对来说是比较多的,占据了很大的比例;而少儿、医疗等类别则相对较少,有的时候一天也没有几篇对应的文章。
3.**多语言。**一般来说,咱们主要的语言除了中文,应该是大多数人只会英语了,不过为了考虑到新闻来源的广泛性,咱们也假定这批文本是多语言的。
刚才提到了因为Bert是比较典型的深度学习NLP算法模型也是业界使用最广泛的模型之一。如果拿下这么有代表性的模型以后你学习和使用基于Attention的模型你也能举一反三比如GPT等。
想要用好BERT我们需要先了解它有哪些特点。
## BERT原理与特点分析
BERT的全称是Bidirectional Encoder Representation from Transformers即双向Transformer的Encoder。作为一种基于Attention方法的模型它最开始出现的时候可以说是抢尽了风头在文本分类、自动对话、语义理解等十几项NLP任务上拿到了历史最好成绩。
在[第22节课](https://time.geekbang.org/column/article/461691)如果不熟悉可以回看我们已经了解了Attention的基本原理有了这个知识做基础我们很容易就能快速掌握BERT的原理。
这里我再快速给你回顾一下BERT的理论框架主要是基于论文《Attention is all you need》中提出的Transformer而后者的原理则是刚才提到的Attention。**其最为明显的特点就是摒弃了传统的RNN和CNN逻辑有效解决了NLP中的长期依赖问题。**
![图片](https://static001.geekbang.org/resource/image/57/e7/57129ea84051eaf5985535dcb97c1fe7.jpg?wh=1920x1269 "图片来源https://arxiv.org/abs/1706.03762")
在BERT中它的输入部分也就是图片的左边其实是由N个多头Attention组合而成。多头Attention是将模型分为多个头形成多个子空间可以让模型去关注不同方面的信息这有助于网络捕捉到更丰富的特征或者信息。具体原理一定要查阅[《Attention is all you need》](https://arxiv.org/abs/1706.03762)哦)。
结合上图我们要注意的是BERT采用了基于MLM的模型训练方式即Mask Language Model。因为BERT是Transformer的一部分即encoder环节所以没有decoder的部分其实就是GPT
为了解决这个问题MLM方式应运而生。它的思想也非常简单就是在**训练之前随机将文本中一部分的词语token进行屏蔽mask然后在训练的过程中使用其他没有被屏蔽的token对被屏蔽的token进行预测**。
![图片](https://static001.geekbang.org/resource/image/ae/84/aeed42d94750436f1dyye31f92c96584.jpg?wh=1920x700)
用过Word2Vec的小伙伴应该比较清楚在Word2Vec中对于同一个词语它的向量表示是固定的这也就是为什么会有那个经典的“_国王-男人+女人=皇后_”计算式了。
但是有一个问题“苹果”这个词有可能是水果的苹果也可能是电子产品的品牌如果还是用同一个向量表示这样就有可能产生偏差。而在BERT中则不一样根据上下文的不同对于同一个token给出的词向量是动态变化的更加灵活。
此外BERT还有多语言的优势。在以前的算法中比如SVM如果要做多语言的模型就要涉及分词、提取关键词等操作而这些操作要求你对该语言有所了解。像阿拉伯文、日语等语言咱们大概率是看不懂的这会对我们最后的模型效果产生极大影响。
BERT则不需要担心这个问题通过基于字符、字符片段、单词等不同粒度的token覆盖并作WordPiece能够覆盖上百种语言甚至可以说只要你能够发明出一种逻辑上自洽的语言BERT就能够处理。有关WordPiece的介绍你可以通过[这里](https://paperswithcode.com/method/wordpiece)做拓展阅读。
说了这么多集高效、准确、灵活再加上用途广泛于一体的BERT自然而然就成为了咱们的首选下面咱们开始正式构建一个文本分类模型。
## 安装与准备
工欲善其事,必先利其器,在开始构建模型之前,我们要安装相应的工具,然后下载对应的预先训练好的模型,同时还要了解数据的格式。
### 环境准备
因为咱们要做的是一个基于PyTorch 的BERT模型那么就要安装对应的python包这里我选择的是hugging face的PyTorch版本的Transformers包。你可以通过pip命令直接安装。
```
pip install Transformers
```
### 模型准备
安装之后我们打开Transformers的[git页面](https://github.com/huggingface/transformers),并找到如下的文件夹。
```plain
src/Transformers/models/BERT
```
从这个文件夹里我们需要找到两个很重要的文件分别是convert\_BERT\_original\_tf2\_checkpoint\_to\_PyTorch.py和modeling\_BERT.py文件。
先来看第一个文件你看看名字是不是就能猜出来它大概是用来做什么的了没错就是用来将原来通过TensorfFlow预训练的模型转换为PyTorch的模型。
然后是modeling\_BERT.py文件这个文件实际上是给了你一个使用BERT的范例。
下面,咱们开始准备模型,打开[这个地址](https://github.com/tensorflow/models/tree/master/official/nlp/bert),你会发现在这个页面中,有几个预训练好的模型。
![图片](https://static001.geekbang.org/resource/image/f7/a9/f7429816e9c736d99be4b55c67bac6a9.png?wh=1920x956)
对照这节课的任务我们选择的是“BERT-Base, Multilingual Cased”的版本。从GitHub的介绍可以看出这个版本的checkpoint支持104种语言是不是很厉害当然如果你没有多语言的需求也可以选择其他版本的它们的区别主要是网络的体积不同。
转换完模型之后你会发现你的本地多了三个文件分别是config.json、pytorch\_model.bin和vocab.txt。我来分别给你说一说。
![图片](https://static001.geekbang.org/resource/image/a8/00/a85bfecfb02108cf7e46d5bef74efe00.jpg?wh=1920x228)
1.config.json顾名思义该文件就是BERT模型的配置文件里面记录了所有用于训练的参数设置。
2.PyTorch\_model.bin模型文件本身。
3.vocab.txt词表文件。尽管BERT可以处理一百多种语言但是它仍旧需要词表文件用于识别所支持语言的字符、字符串或者单词。
### 格式准备
现在模型准备好了我们还要看看跟模型匹配的格式。BERT的输入不算复杂但是也需要了解其形式。在训练的时候我们输入的数据不能是直接把词塞到模型里而是要转化成后面这三种向量。
1.**Token embeddings**词向量。这里需要注意的是Token embeddings的第一个开头的token一定得是“\[CLS\]”。\[CLS\]作为整篇文本的语义表示,用于文本分类等任务。
2.**Segment embeddings**。这个向量主要是用来将两句话进行区分,比如问答任务,会有问句和答句同时输入,这就需要一个能够区分两句话的操作。不过在咱们此次的分类任务中,只有一个句子。
3.**Position embeddings**。记录了单词的位置信息。
## 模型构建
准备工作已经一切就绪我们这就来搭建一个基于BERT的文本分类网络模型。这包括了**网络的设计、配置、以及数据准备,这个过程也是咱们的核心过程**。
### 网络设计
从上面提到的modeling\_BERT.py文件中我们可以看到作者实际上已经给我们提供了很多种类的NLP任务的示例代码咱们找到其中的“BERTForSequenceClassification”这个分类网络我们可以直接使用它也是最最基础的BERT文本分类的流程。
这个过程包括了利用**BERT得到文本的embedding表示**、**将embedding放入全连接层得到分类结果**两部分。我们具体看一下代码。
```python
class BERTForSequenceClassification(BERTPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels//类别标签数量
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)//还记得Dropout是用来做什么的吗可以一定程度防止过拟合
self.classifier = nn.Linear(config.hidden_size, config.num_labels)//BERT输出的embedding传入一个MLP层做分类
self.init_weights()
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]//这个就是经过BERT得到的中间输出
pooled_output = self.dropout(pooled_output)//就是为了减少过拟合和增加网络的健壮性
logits = self.classifier(pooled_output)//多层MLP输出最后的分类结果
```
对照前面的代码可以发现接收到输入信息之后BERT返回了一个outputsoutputs包括了模型计算之后的全部结果不仅有每个token的信息也有整个文本的信息这个输出具体包括以下信息。
last\_hidden\_state是模型最后一层输出的隐藏层状态序列。shape是(batch\_size, sequence\_length, hidden\_size)。其中hidden\_size=768这个部分的状态就相当于利用sequence\_length \* 768维度的矩阵记录了整个文本的计算之后的每一个token的结果信息。
pooled\_output代表序列的第一个token的最后一个隐藏层的状态。shape是(batch\_size, hidden\_size)。所谓的第一个token就是咱们刚才提到的\[CLS\]标签。
除了上面两个信息还有hidden\_states、attentions、cross attentions。有兴趣的小伙伴可以去查一下它们有何用途。
通常的任务中我们用得比较多的是last\_hidden\_state对应的信息我们可以用pooled\_output = outputs\[1\]来进行获取。
至此我们已经有了经过BERT计算的文本向量表示然后我们将其输入到一个linear层中进行分类就可以得到最后的分类结果了。**为了提高模型的表现我们往往会在linear层之前加入一个dropout层这样可以减少网络的过拟合的可能性同时增强神经元的独立性**。
### 模型配置
设计好网络我们还要对模型进行配置。还记得刚才提到的config.json文件么这里面就记录了BERT模型所需的所有配置信息我们需要对其中的几个内容进行调整这样模型就能知道我们到底是要做什么事情了。
后面这几个字段我专门说一下。
* id2label这个字段记录了类别标签和类别名称的映射关系。
* label2id这个字段记录了类别名称和类别标签的映射关系。
* num\_labels\_cate类别的数量。
## 数据准备
模型网络设计好了配置文件也搞定了下面我们就要开始数据准备这一步了。这里的数据准备是指将文本转换为BERT能够识别的形式即前面提到的三种向量在代码中对应的就是input\_ids、token\_type\_ids、attention\_mask。
为了生成这些数据我们需要在git中找到“src/Transformers/data/processors/utils.py”文件在这个文件中我们要用到以下几个内容。
1.InputExample它用于记录单个训练数据的文本内容的结构。
2.DataProcessor通过这个类中的函数我们可以将训练数据集的文本表示为多个InputExample组成的数据集合。
3.get\_features用于把InputExample数据转换成BERT能够理解的数据结构的关键函数。我们具体来看一下各个数据都怎么生成的。
input\_ids记录了输入token对应在vocab.txt的id序号它是通过如下的代码得到的。
```python
input_ids = tokenizer.encode(
example.text_a,
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
)
```
而attention\_mask记录了属于第一个句子的token信息通过如下代码得到。
```python
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
```
另外不要忘记记录文本类别的信息label。你可以自己想想看能否按照utils.py文件中的声明方式构建出对应的label信息呢
## 模型训练
到目前为止我们有了网络结构定义BERTForSequenceClassification、数据集合get\_features现在就可以开始编写实现训练过程的代码了。
### 选择优化器
首先我们来选择优化器代码如下。我们要对网络中的所有权重参数进行设置这样优化器就可以知道哪些参数是要进行优化的。然后我们将参数list放到优化器中BERT使用的是AdamW优化器。
```plain
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
```
这部分的代码,主要是为了选择一个合适咱们模型任务的优化器,并将网络中的参数设定好学习率。
### 构建训练过程逻辑
训练的过程逻辑是非常简单的只需要两个for循环分别代表epoch和batch然后在最内部增加**一个训练核心语句,**以及**一个梯度更新语句**这就足够了。可以看到PyTorch在工程代码的实现上封装得非常完善和简练。
```plain
for epoch in trange(0, args.num_train_epochs):
model.train()//一定别忘了要把模型设置为训练状态。
for step, batch in enumerate(tqdm(train_dataLoader, desc='Iteration')):
step_loss = training_step(batch)//训练的核心环节
tr_loss += step_loss[0]
optimizer.step()
optimizer.zero_grad()
```
### 训练的核心环节
训练的核心环节,你需要关注两个部分,分别是**通过网络得到预测输出**也就是logits以及**基于logits计算得到的loss**loss是整个模型使用梯度更新需要用到的数据。
```plain
def training_step(batch):
input_ids, token_type_ids, attention_mask, labels = batch
input_ids = input_ids.to(device)//将数据发送到GPU
token_type_ids = token_type_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels_voc.to(device)
        
logits = model(input_ids,
token_type_ids=token_type_ids, 
attention_mask=attention_mask, 
labels=labels)
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, num_labels_cate), labels.view(-1, num_labels_cate).float())
loss.backward()
```
至此咱们已经快速构建出了一个BERT分类器所需的所有关键代码。但是仍旧有一些小小的环节需要你来完善比如training\_step代码块中的device是怎么得到的呢回顾一下咱们之前学习的内容相信你一定可以做得到。
## 小结
恭喜你完成了这节课的学习尽管现在GitHub上已经有了很多已经封装得非常完善的BERT代码你也可以很快实现一个最基本的NLP算法流程但是我仍希望你能够抽出时间好好看一下Transformer中的模型代码这会对你的技术提升有非常大的助益。
这节课我们学习了如何用PyTorch快速构建一个基本的文本分类模型想要实现这个过程你需要了解BERT的预训练模型的获取以及转化、分类网络的设计方法、训练过程的编写。整个过程不难但是却可以让你快速上手了解PyTorch在NLP方面如何应用。
除了技术本身,业务方面的考虑我们也要注意。比如新闻文本的多语言、数据不平衡等问题,模型有时不能解决所有的问题,因此你还需要学习一些**数据预处理的技巧**,这包括很多技术和算法方面的内容。
即使我列出一份长长的学习清单也可能会挂一漏万所以数据预处理方面的知识我建议你重点关注以下内容建议你需要花一些时间去学习NumPy和Pandas的使用这样才能更加得心应手地处理数据你还可以多学习一些常见的数据挖掘算法比如决策树、KNN、支持向量机等另外深度学习的广泛使用其实仍旧非常需要传统机器学习算法的背后支撑也建议你多多了解。
## 思考题
BERT处理文本是有最大长度要求的512那么遇到长文本该怎么办呢
也欢迎你在留言区记录你的疑问或者收获,也推荐你把这节课分享给你的朋友。