import matplotlib.pyplot as plt
import numpy as np
def visualize_attention(attention_weights, tokens, layer=0, head=0,
                        max_display_tokens=100, figsize_scale=0.2):
    """
    Visualize attention weights (optimized for long sequences).

    Args:
        attention_weights: tuple of per-layer attention tensors,
            each shaped (batch, heads, seq_len, seq_len)
        tokens: list of text tokens
        layer: index of the layer to visualize
        head: index of the attention head to visualize
        max_display_tokens: maximum number of tokens to display
            (longer sequences are truncated)
        figsize_scale: figure inches allocated per token
            (controls the overall plot size)
    """
    # Pull out one layer/head for the first batch item; move to CPU
    # in case the model ran on GPU.
    attn = attention_weights[layer][0, head].detach().cpu().numpy()

    # Truncate long sequences so the heatmap stays readable.
    display_tokens = tokens[:max_display_tokens]
    display_attn = attn[:max_display_tokens, :max_display_tokens]
    display_len = len(display_tokens)

    # Scale the figure with sequence length, clamped to [1, 30] inches.
    side = max(1, min(int(display_len * figsize_scale), 30))
    plt.figure(figsize=(side, side))
    plt.imshow(display_attn, cmap='hot', interpolation='nearest')
    plt.colorbar(shrink=0.8)

    # Label at most ~20 evenly spaced ticks so token labels stay legible.
    step = max(1, display_len // 20)
    tick_positions = range(0, display_len, step)
    tick_labels = [display_tokens[i] for i in tick_positions]
    plt.xticks(tick_positions, tick_labels, rotation=90, fontsize=6)
    plt.yticks(tick_positions, tick_labels, fontsize=6)

    plt.title(f"Attention Weights - Layer {layer}, Head {head}", fontsize=10)
    plt.tight_layout()
    plt.show()
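The calls below assume a tokenizer and a model loaded with output_attentions=True. A minimal sketch of that setup with Hugging Face Transformers follows; the checkpoint name and sample text are placeholders, not from the original:

from transformers import AutoTokenizer, AutoModel

model_name = "bert-base-uncased"  # illustrative checkpoint; substitute your own
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

text = "The quick brown fox jumps over the lazy dog."  # sample input
encoded_text = tokenizer(text, return_tensors="pt")
outputs = model(**encoded_text)
# outputs.attentions is a tuple with one (batch, heads, seq_len, seq_len)
# tensor per layer, matching what visualize_attention expects.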
# Convert input ids back to token strings for the axis labels,
# then plot layer 0, head 0 for the first sequence in the batch.
tokens = tokenizer.convert_ids_to_tokens(encoded_text['input_ids'][0])
visualize_attention(outputs.attentions, tokens, layer=0, head=0,
                    max_display_tokens=150, figsize_scale=0.25)
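To sanity-check the plotting code without downloading a model, you can also feed it synthetic attentions with the same shapes; everything in this sketch (sizes, token names) is made up for the test:

import torch

# Fake attentions: 2 layers of (batch=1, heads=4, seq=12, seq=12),
# row-normalized with softmax like real attention weights.
seq_len, n_heads, n_layers = 12, 4, 2
fake_attentions = tuple(
    torch.softmax(torch.randn(1, n_heads, seq_len, seq_len), dim=-1)
    for _ in range(n_layers)
)
fake_tokens = [f"tok{i}" for i in range(seq_len)]
visualize_attention(fake_attentions, fake_tokens, layer=1, head=2)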