Fine-tuning Llama-2-7b-hf with LoRA for Fraud SMS Detection

This post covers the implementation of the fraud-SMS detection feature of our 2024 Challenge Cup project, a large-model-based multimodal risk content recognition system.

Approach

A Llama model in Hugging Face format, fine-tuned in code with LoRA.

Environment Setup

GPU server: dual RTX 4090 (24 GB each), CUDA 12

Python: 3.11

Because 40-series GPUs do not support some of NCCL's faster communication paths (peer-to-peer and InfiniBand), the following environment variables need to be set:

export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
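
If you prefer to keep this inside the training script, the same settings can be applied from Python, as long as it runs before torch initializes NCCL (a small equivalent of the shell exports above):

import os

# Equivalent to the shell exports above; must run before any distributed/NCCL initialization.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"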

Model Preparation

Model Download

Download the Llama-2-7b-hf model; here I used the model resources curated by the Llama Chinese community:

LlamaFamily/Llama-Chinese: the Llama Chinese community, which continuously aggregates the latest Llama learning resources and builds an open-source ecosystem for Chinese Llama models, fully open source and commercially usable.
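
The post does not pin down a specific download command; as one option, the Hugging Face Hub weights can be pulled with huggingface_hub (the repo id below is an assumption, and the official meta-llama repo requires accepting the license first):

from huggingface_hub import snapshot_download

# Download the Hugging Face-format weights into the local path used in the rest of this post.
# The repo id is an assumption; substitute whichever mirror or source you actually use.
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-hf",
    local_dir="/home/data/pre_model/Llama-2-7b-hf",
)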

Model Verification

The downloaded model can be tested with the code below; remember to change the model path to wherever you saved it, here /home/data/pre_model/Llama-2-7b-hf.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "/home/data/pre_model/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",  # automatically place the model across the available GPUs
).eval()  # evaluation mode (disables dropout) for inference

input_text = "How to learn skiing?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_length=256,
        do_sample=True,  # sampling produces more natural text
        temperature=0.7,
        top_p=0.9
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The output is shown below; the generated text is reasonably fluent.

How to learn skiing?
Skiing is an exciting and fun winter activity that many people love. While skiing can be challenging at first, with the right instruction and practice, anyone can learn how to ski.
Learning to ski is a process that requires patience and practice. It is important to start with the basics, such as learning how to balance on skis, and progress gradually to more advanced techniques.
The best way to learn how to ski is to take lessons from a qualified instructor. A qualified instructor will be able to teach you the basics of skiing, such as balance, turning, and stopping. They will also be able to teach you more advanced techniques, such as carving and jumping.
Another way to learn how to ski is to practice on a ski slope. Ski slopes are designed to help you learn how to ski safely and effectively. They are usually divided into different levels, so you can start on a beginner slope and gradually progress to more challenging slopes.
It is also important to wear the right equipment when learning how to ski. This includes a helmet, goggles, and warm clothing. Wearing the right

Preparing the Dataset for LoRA Fine-tuning

The dataset is ChangMianRen/Telecom_Fraud_Texts_5, which contains a large number of labeled fraud and normal SMS samples.
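
As a quick sanity check, the corpus can be pulled straight from the Hugging Face Hub (the split name and column layout below are assumptions; check the dataset card for the exact schema):

from datasets import load_dataset

ds = load_dataset("ChangMianRen/Telecom_Fraud_Texts_5")
print(ds)              # available splits and columns
print(ds["train"][0])  # expected to look like {"content": "...", "label": 0}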

The data is preprocessed into a dataset that fits the LoRA fine-tuning format (a sketch of this step follows the example rows below).

The raw data is organized as rows of the form:

content label
最后小时,在微信添加朋友中输入良品铺子美食旅行关注参与活动并抢最高DIGIT元红包。如需退订请回复TD或直接退出良品铺子的公众号即可! 0
你好,我是贷款公司的代表。你是否有资金需求?我们提供低利率、快速审批的贷款服务。如果你感兴趣的话请添加我的微信号:xxxxxxxxx。 1
你好,是满梦园吗?我这里是公安机关的民警。我们发现您的身份信息可能被泄露了,涉嫌诈骗活动。我们需要您协助调查此事。请下载我们的”teams”app并与我们在上面进行交流。谢谢配合! 1
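
A minimal sketch of the preprocessing step that produces the train.csv / valid.csv files used later; the split name, the 9:1 ratio and the stratified split are my assumptions, not something fixed by the original pipeline:

from datasets import load_dataset
from sklearn.model_selection import train_test_split

# Keep only the two columns shown above and split off a validation set.
df = load_dataset("ChangMianRen/Telecom_Fraud_Texts_5")["train"].to_pandas()
df = df[["content", "label"]].dropna()
train_df, valid_df = train_test_split(
    df, test_size=0.1, stratify=df["label"], random_state=42
)
train_df.to_csv("./train.csv", index=False)
valid_df.to_csv("./valid.csv", index=False)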

The prompt template applied is:

"""
### Instruction:
你是一个专门识别诈骗短信的专家,请判断输入的短信是否是诈骗短信,如果是,请回答True,否则回答False。
诈骗短信一般具有以下特征:
1. 诱导点击链接或拨打电话或添加微信
2. 内容涉及赌博、中奖、钱财等
3. 使用特殊符号或文字,或使用符号隔断文字
4. 使用黑话/暗语,令人难以理解

### Input:{}

### Response:{}
</s>"""

Each record's content fills the Input slot; the Response is True when label is 1 and False when label is 0.

Fine-tuning Code

The meaning of the trainer arguments is explained in the Zhihu article "huggingface transformers使用指南之二——方便的trainer".

from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, SFTConfig
from torch.utils.data import Dataset
import pandas as pd

class SMSDataset(Dataset):
    def __init__(self, data_path):
        self.data = pd.read_csv(data_path)
        self.prompt_template = """
### Instruction:
你是一个专门识别诈骗短信的专家,请判断输入的短信是否是诈骗短信,如果是,请回答True,否则回答False。
诈骗短信一般具有以下特征:
1. 诱导点击链接或拨打电话或添加微信
2. 内容涉及赌博、中奖、钱财等
3. 使用特殊符号或文字,或使用符号隔断文字
4. 使用黑话/暗语,令人难以理解

### Input:{}

### Response:{}
</s>"""

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raw_data = self.data.iloc[idx]
        # Fill the template with the SMS text and the True/False answer.
        prompt_data = self.prompt_template.format(
            raw_data['content'], "True" if raw_data['label'] == 1 else "False"
        )
        # The module-level tokenizer defined below is only used lazily, at training time.
        prompt_data = tokenizer(prompt_data)
        return prompt_data

SMStrainDataset = SMSDataset("./train.csv")
SMSvalidDataset = SMSDataset("./valid.csv")


model_path = "/home/data/pre_model/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    # load_in_8bit=True
)
model.enable_input_require_grads()  # needed for gradient checkpointing with PEFT adapters
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default


lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

training_args = SFTConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    num_train_epochs=1,
    gradient_checkpointing=True,
    # max_steps=60,
    learning_rate=2e-4,
    optim="adamw_torch",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=3407,
    output_dir="./results",
    report_to="none",
    max_seq_length=512,
    dataset_num_proc=4,
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=SMStrainDataset,
    eval_dataset=SMSvalidDataset,
    peft_config=lora_config,
)

trainer.train()

model.save_pretrained('./lora_model')

Note that a "tensors are not on the same device" error came up during training. After some investigation, the fix was to add the following line inside the ForCausalLMLoss function in the transformers library's loss_utils.py:

num_items_in_batch=num_items_in_batch.to(logits.device)

This resolved the device mismatch.

Evaluation

A test script was written for evaluation, taking the first five characters of the model's response as the prediction:

rsp = output[len(input_text):].strip()  # drop the prompt, keep only the model's response
if "True" in rsp[:5] and label == True:
    current += 1  # current counts correct predictions
elif "False" in rsp[:5] and label == False:
    current += 1
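
For context, here is a minimal sketch of the evaluation loop around that snippet, assuming the adapter saved to ./lora_model, a held-out test.csv with the same content/label columns, and greedy decoding; the file name, generation settings and metric calls are assumptions, while the True/False matching mirrors the lines above:

import pandas as pd
import torch
from peft import PeftModel
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/home/data/pre_model/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)
base_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
# Attach the LoRA adapter saved by model.save_pretrained('./lora_model').
model = PeftModel.from_pretrained(base_model, "./lora_model").eval()

# Same format as the training template, with the Response left for the model to complete.
prompt_template = """
### Instruction:
你是一个专门识别诈骗短信的专家,请判断输入的短信是否是诈骗短信,如果是,请回答True,否则回答False。
诈骗短信一般具有以下特征:
1. 诱导点击链接或拨打电话或添加微信
2. 内容涉及赌博、中奖、钱财等
3. 使用特殊符号或文字,或使用符号隔断文字
4. 使用黑话/暗语,令人难以理解

### Input:{}

### Response:"""

test_data = pd.read_csv("./test.csv")  # assumed held-out file with content/label columns
preds, labels = [], []
for _, row in test_data.iterrows():
    input_text = prompt_template.format(row["content"])
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=8, do_sample=False)
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    rsp = output[len(input_text):].strip()
    preds.append(1 if "True" in rsp[:5] else 0)
    labels.append(int(row["label"]))

print("accuracy:", round(accuracy_score(labels, preds), 3))
print("f1:", round(f1_score(labels, preds), 3))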

Comparing the original model with the fine-tuned model gives the following results:

Metric      Original model    Fine-tuned model
Accuracy    0.180             0.977
F1 score    0.294             0.968