使用LoRA微调Llama-2-7b-hf实现涉诈短信识别
本博客为2024挑战杯项目基于大模型的多模态风险内容识别系统的涉诈短信识别功能的实现。
方案选择
Huggingface格式LLama模型+Lora代码微调
环境准备
GPU服务器:RTX 4090,24G双GPU,cuda12
Python: 3.11
由于40系GPU不支持某些高效的通信模式,需要设置环境变量:
1 | export NCCL_P2P_DISABLE=1 |
模型准备
模型下载
下载Llama-2-7b-hf模型,使用的是Llama中文社区整理的模型资源。
LlamaFamily/Llama-Chinese: Llama中文社区,实时汇总最新Llama学习资料,构建最好的中文Llama大模型开源生态,完全开源可商用
模型验证
可以用以下代码测试下载的模型的效果,注意修改模型保存的路径,此处为/home/data/pre_model/Llama-2-7b-hf。
1 | from transformers import AutoTokenizer, AutoModelForCausalLM |
输出如下,可以看出生成的文本比较流畅。
1 | How to learn skiing? |
LoRA微调数据集准备
使用ChangMianRen/Telecom_Fraud_Texts_5,其中包含了大量经过标记的诈骗短信和正常短信样本。
将数据进行预处理,得到符合LoRA微调格式的数据集。
原始数据整理为形如:
| content | label |
|---|---|
| 最后小时,在微信添加朋友中输入良品铺子美食旅行关注参与活动并抢最高DIGIT元红包。如需退订请回复TD或直接退出良品铺子的公众号即可! | 0 |
| 你好,我是贷款公司的代表。你是否有资金需求?我们提供低利率、快速审批的贷款服务。如果你感兴趣的话请添加我的微信号:xxxxxxxxx。 | 1 |
| 你好,是满梦园吗?我这里是公安机关的民警。我们发现您的身份信息可能被泄露了,涉嫌诈骗活动。我们需要您协助调查此事。请下载我们的”teams”app并与我们在上面进行交流。谢谢配合! | 1 |
应用的模版为:
1 | """ |
将数据中的content作为input,label为1时Response为True,为0时Response为False。
微调代码
训练器的参数意义可以参考huggingface transformers使用指南之二——方便的trainer - 知乎
1 | from peft import get_peft_model, LoraConfig, TaskType |
值得注意的是,过程中出现了张量不在同一设备的情况,经过检查,在transformers库的loss_utils.py文件内的
ForCausalLMLoss函数内增加
1 | num_items_in_batch=num_items_in_batch.to(logits.device) |
解决了设备不同的问题。
效果验证
构造测试脚本进行测试,取模型输出的前五个字符作为判断结果
1 | rsp=output[len(input_text):].strip() |
对比原始模型和微调后模型结果如下:
| 指标 | 原始模型 | 微调后模型 |
|---|---|---|
| 准确率 | 0.180 | 0.977 |
| F1分数 | 0.294 | 0.968 |