Hand-Writing Qwen3-0.6B Inference with PyTorch

Motivation

To understand and write steering-vector extraction code, I first needed a solid grasp of the model architecture and the details of inference. I happened to read a post on hand-writing Qwen3-0.6B inference with PyTorch, so I used it as the basis for this exercise.

Steps

1 Model download

Download Qwen/Qwen3-0.6B (main branch) from a mirror site.
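If you prefer to script the download with huggingface_hub, a minimal sketch is below; the mirror endpoint is an assumption, so substitute whichever mirror you normally use. The local directory matches the Qwen0.6BFiles path used throughout this post.

import os
# Point huggingface_hub at a mirror before importing it (endpoint is only an example)
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download
# Fetch the whole repo into the directory used by the code later in this post
snapshot_download(repo_id="Qwen/Qwen3-0.6B", local_dir="./Qwen0.6BFiles")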

2 Model contents

  • lm_head.weight: torch.Size([151936, 1024])

    The output head: maps the model's final hidden state to scores (logits) over the vocabulary.

  • model.embed_tokens.weight: torch.Size([151936, 1024])

    The token embedding: turns token ids into hidden vectors; the matrix shape is [vocab_size, hidden_size].

  • model.layers.[n].input_layernorm.weight: torch.Size([1024])

    Normalizes the layer input to keep activations stable and prevent exploding gradients; shape is [hidden_size].

  • model.layers.[n].mlp.down_proj.weight: torch.Size([1024, 3072])

    Linear projection from intermediate_size (3072) back down to hidden_size (1024).

  • model.layers.[n].mlp.gate_proj.weight: torch.Size([3072, 1024])

    Linear projection from hidden_size up to intermediate_size (the gate branch of the MLP).

  • model.layers.[n].mlp.up_proj.weight: torch.Size([3072, 1024])

    Linear projection from hidden_size up to intermediate_size (the up branch of the MLP).

  • model.layers.[n].post_attention_layernorm.weight: torch.Size([1024])

    Normalization after attention and before the FFN; shape is [hidden_size].

  • model.layers.[n].self_attn.q_proj.weight: torch.Size([2048, 1024])

    Transforms each token's hidden vector into Q vectors; there are 16 query heads, and every 2 query heads share one group of K/V (grouped-query attention).

  • model.layers.[n].self_attn.k_proj.weight: torch.Size([1024, 1024])

    Transforms each token's hidden vector into K vectors; 8 heads.

  • model.layers.[n].self_attn.v_proj.weight: torch.Size([1024, 1024])

    Transforms each token's hidden vector into V vectors; 8 heads.

  • model.layers.[n].self_attn.o_proj.weight: torch.Size([1024, 2048])

    Output projection: maps the concatenated, softmax-weighted attention output back to the hidden dimension.

  • model.layers.[n].self_attn.q_norm.weight: torch.Size([128])

    RMSNorm over the head_dim dimension inside each Q head; stabilizes the distribution of attention scores.

  • model.layers.[n].self_attn.k_norm.weight: torch.Size([128])

    RMSNorm over the head_dim dimension inside each K head; stabilizes the distribution of attention scores.

  • model.norm.weight: torch.Size([1024])

    The final RMSNorm before the lm_head; shape is [hidden_size]. (The full list of names and shapes can be reproduced with the snippet after this list.)
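The shapes above can be read straight from the checkpoint. A minimal sketch, assuming the safetensors file has been downloaded to Qwen0.6BFiles/model.safetensors (the same path used by the code later in this post):

from safetensors import safe_open

# Print every weight name and its shape, matching the listing above
with safe_open("Qwen0.6BFiles/model.safetensors", framework="pt") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))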

3 Model architecture

(Figure: model architecture diagram)

4 Code implementation

  • For comparison: inference directly with the transformers library
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Qwen3-0.6B with transformers
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/home/data/sfx")
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="/home/data/sfx").to("cuda")

# Direct greedy generation
message = "<|im_start|>user明天做点啥<|im_end|><|im_start|>assistant"
inputs = tokenizer(message, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=3000, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
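The message string above hand-writes Qwen's chat special tokens. The more usual way to build the same prompt is the tokenizer's chat template; a sketch (note the template also inserts newlines around the special tokens, so the resulting string differs slightly from the hand-written one):

# Build the prompt via the chat template instead of writing special tokens by hand
messages = [{"role": "user", "content": "明天做点啥"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=3000, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))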
  • Define the Config class
import json
import torch
from torch import nn
import torch.nn.functional as F

class SelfQwen3Config:
    def __init__(self, config_dict=None):
        for key, value in config_dict.items():
            setattr(self, key, value)

with open("./Qwen0.6BFiles/config.json", "r") as f:
    config_dict = json.load(f)

config = SelfQwen3Config(config_dict)

print("DONE")
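The hand-written model only reads a handful of fields from config.json, so it helps to print them once. The names below are exactly the attributes consumed by the code that follows:

# Fields consumed by the hand-written implementation
for name in ["vocab_size", "hidden_size", "intermediate_size",
             "num_hidden_layers", "num_attention_heads",
             "num_key_value_heads", "head_dim", "rms_norm_eps", "rope_theta"]:
    print(name, getattr(config, name))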
  • Define the model
def apply_rotary_pos_emb(q, k, position_ids, head_dim, rope_theta=1000000.0):
    device = q.device
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim))
    freqs = position_ids.unsqueeze(-1).float() * inv_freq.unsqueeze(0).unsqueeze(0)
    emb = torch.cat([freqs, freqs], dim=-1)
    cos = emb.cos().unsqueeze(1).to(q.dtype)
    sin = emb.sin().unsqueeze(1).to(q.dtype)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU-style MLP: down_proj(silu(gate_proj(x)) * up_proj(x))
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

class SelfQwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = SelfQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = SelfQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.rope_theta = config.rope_theta

    def forward(self, hidden_states, position_ids=None, attention_mask=None):
        bsz, q_len, _ = hidden_states.size()
        # Project to Q, K, V and split into heads
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # Per-head RMSNorm on Q and K
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)
        # Apply rotary position embedding
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, position_ids, self.head_dim, self.rope_theta)
        # Grouped-query attention
        if self.num_key_value_groups > 1:
            # Repeat each KV head for its group of query heads so K/V match the number of Q heads
            key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
            value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
        # Attention scores
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax to get attention weights
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
        # Weighted sum over values
        attn_output = torch.matmul(attn_weights, value_states)

        # Merge heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        # Output projection
        attn_output = self.o_proj(attn_output)
        return attn_output

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.post_attention_layernorm = SelfQwen3RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.mlp = MLP(config)
        self.input_layernorm = SelfQwen3RMSNorm(config.hidden_size, config.rms_norm_eps)

        # HookPoint is defined in section 5 below; run that cell first
        self.hook_attn_output = HookPoint()

    def forward(self, hidden_states, position_ids=None, attention_mask=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(hidden_states, position_ids, attention_mask)

        attn_output = self.hook_attn_output(attn_output)

        hidden_states = residual + attn_output
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

class SelfQwen3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = SelfQwen3RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie lm_head to the embedding matrix
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids):
        bsz, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        # Causal mask: upper triangle is -inf so a token cannot attend to future tokens
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), dtype=torch.float32, device=input_ids.device),
            diagonal=1
        )
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=causal_mask)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits
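Before loading real weights, a quick smoke test on randomly initialized weights confirms the shapes line up. A minimal sketch (note that DecoderLayer references HookPoint, so the class in section 5 has to be defined first); the expected output shape is [batch, seq_len, vocab_size]:

# Shape smoke test with randomly initialized weights
test_model = SelfQwen3Model(config)
dummy_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = test_model(dummy_ids)
print(out.shape)  # expected: torch.Size([1, 8, 151936])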
  • Inference
from safetensors import safe_open
from tokenizers import Tokenizer

model = SelfQwen3Model(config)
new_state_dict = {}
with safe_open("Qwen0.6BFiles/model.safetensors", framework="pt") as f:
    for k in f.keys():
        v = f.get_tensor(k)
        if k.startswith("model."):
            # Strip the "model." prefix so keys match the hand-written module names
            new_key = k[len("model."):]
            new_state_dict[new_key] = v
        else:
            new_state_dict[k] = v

model.load_state_dict(new_state_dict, strict=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Tokenize the prompt into input ids
tokenizer = Tokenizer.from_file(str("Qwen0.6BFiles/tokenizer.json"))
message = "<|im_start|>user明天做点啥<|im_end|><|im_start|>assistant"
input_ids = tokenizer.encode(message).ids
input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

# Greedy decoding loop
with torch.no_grad():
    cnt = 0
    while True:
        logits = model(input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        if next_token.item() == 151645:  # <|im_end|> ends the reply
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)
        cnt += 1
        if cnt > 3000:
            break

# Decode and print the result
output_text = tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)
print(output_text)
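As a sanity check, the hand-written forward pass can be compared against the transformers model from the first cell on the same input. A sketch, where hf_model is a hypothetical name for that earlier AutoModelForCausalLM instance moved to the same device (it is called model in the first cell, so rename one of the two):

# Rough cross-check of next-token prediction against the reference implementation
with torch.no_grad():
    my_logits = model(input_ids)            # hand-written forward pass
    hf_logits = hf_model(input_ids).logits  # transformers forward pass
print(torch.argmax(my_logits[0, -1]).item(),
      torch.argmax(hf_logits[0, -1]).item())  # should agree on the next token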

5 Hook function

class HookPoint(nn.Module):
    def __init__(self, hook_fn=None):
        super().__init__()
        self.value = None
        self.handle = None
        self.custom_hook_fn = hook_fn
        self._override_enabled = False

    def default_hook_fn(self, module, input, output):
        # Always capture the module's output
        self.value = output.detach()

        if self._override_enabled:
            # Flip the sign of roughly 40% of the entries to corrupt the output
            neg_mask = torch.where(
                torch.rand_like(output) < 0.4,
                torch.tensor(-1.0, device=output.device, dtype=output.dtype),
                torch.tensor(1.0, device=output.device, dtype=output.dtype),
            )
            return output * neg_mask
        else:
            return output

    def register_hook(self, target_module: nn.Module):
        if self.handle is not None:
            self.handle.remove()
        hook_to_use = self.custom_hook_fn if self.custom_hook_fn else self.default_hook_fn
        self.handle = target_module.register_forward_hook(hook_to_use)

    def remove(self):
        if self.handle is not None:
            self.handle.remove()
        self.handle = None

    def set_override(self):
        self._override_enabled = True

    def clear_override(self):
        self._override_enabled = False

    def get_value(self):
        return self.value

    def forward(self, x):
        # Identity pass-through; DecoderLayer routes attn_output through this so a forward hook can intercept it
        return x
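A minimal standalone sketch of how the class behaves, using a toy nn.Linear instead of the model (the names here are illustrative only):

lin = nn.Linear(4, 4)
hp = HookPoint()
hp.register_hook(lin)            # capture-only: stores the module's output
_ = lin(torch.randn(2, 4))
print(hp.get_value().shape)      # torch.Size([2, 4])

hp.set_override()                # now the hook also flips ~40% of the output's signs
out_corrupted = lin(torch.randn(2, 4))
hp.clear_override()
hp.remove()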

Verify that the hook takes effect

from safetensors import safe_open
from tokenizers import Tokenizer
# Capture (and corrupt) one layer's attention output with a hook

model = SelfQwen3Model(config)
new_state_dict = {}
with safe_open("Qwen0.6BFiles/model.safetensors", framework="pt") as f:
    for k in f.keys():
        v = f.get_tensor(k)
        if k.startswith("model."):
            new_key = k[len("model."):]
            new_state_dict[new_key] = v
        else:
            new_state_dict[k] = v

model.load_state_dict(new_state_dict, strict=True)

torch.random.manual_seed(42)
target_layer = model.layers[18]
target_module = target_layer.hook_attn_output
hook = HookPoint()
hook.set_override()  # enable override mode: layer 18's attention output gets corrupted
hook.register_hook(target_module)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

tokenizer = Tokenizer.from_file(str("Qwen0.6BFiles/tokenizer.json"))
message = "<|im_start|>user明天做点啥<|im_end|><|im_start|>assistant"
input_ids_list = tokenizer.encode(message).ids
input_ids = torch.tensor([input_ids_list], dtype=torch.long, device=device)

with torch.no_grad():
    cnt = 0
    while True:
        logits = model(input_ids)
        captured_value = hook.get_value()
        # print("Captured value shape:", captured_value.shape)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        if next_token.item() == 151645:  # <|im_end|>
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)
        cnt += 1
        if cnt > 3000:
            break

# Decode and print the result
output_text = tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)
print(output_text)

The generated text now shows a clear drop in quality, which confirms that corrupting the attention output at this layer disrupts the model's inference.