几种常用的相对位置编码

本文详细介绍了相对位置编码在机器学习和深度学习中的使用,特别是自然语言处理领域。从经典式到XLNET式、T5式、DeBERTa式,逐一阐述各种方法的原理和差异,并提及FLAT模型在中文NER中的应用,强调了位置编码在模型中的重要性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

主要参考苏神blog,传送门,自己整理下加深记忆

绝对位置编码

绝对位置编码一般形式的带绝对位置编码的Attention形式如下:
{ q i = ( x i + p i ) W Q k i = ( x i + k i ) W K v i = ( x i + v i ) W V a i , j = s o f t m a x ( q i k j T ) o i = ∑ j a i , j v j \left\{ \begin{aligned} q_i & = &(x_i+p_i)W_Q \\ k_i & = & (x_i+k_i)W_K\\ v_i & = & (x_i+v_i)W_V\\ a_{i,j} & = & softmax(q_ik_{j}^{T})\\ o_i & = &\sum_ja_{i,j}v_j \end{aligned} \right. qikiviai,joi=====(xi+pi)WQ(xi+ki)WK(xi+vi)WVsoftmax(qikjT)jai,jvj
其中softmax表示对 j j j那一维归一化,这里的向量都是指行向量。将 q i k j T q_ik_j^T qikjT展开可得:
q i k j T = ( x i + p i ) W Q W K T ( x j + p j ) T = ( x i W Q + p i W Q ) ( W K T x j T + W K T p j T ) q_ik_j^T=(x_i+p_i)W_QW_K^T(x_j+pj)^T=(x_iW_Q+p_iW_Q)(W_K^Tx_j^T+W_K^Tp_j^T) qikjT=(xi+pi)WQWKT(xj+pj)T=(xiWQ+piWQ)(WKTxjT+WKTpjT)

相对位置编码

经典式

相对位置编码有绝对位置编码启发而来,相对位置编码起源于Google的论文《Self-Attention with Relative Position Representations》,华为的NEZHA也用了这种位置编码。Google把第一项位置去掉,第二项 p j W K p_jW_K pjWK改为二元位置向量 R i , j K R_{i,j}^K Ri,jK,即:
a i , j = s o f t m a x ( x i W Q ( x j W K + R i , j K ) T ) a_{i,j}=softmax(x_iW_Q(x_jW_K+R_{i,j}^K)^T) ai,j=softmax(xiWQ(xjWK+Ri,jK)T)
o i o_i oi的计算方式变为:
o i = ∑ j a i , j ( x j W K + R i , j V ) o_i = \sum_ja_{i,j}(x_jW_K+R_{i,j}^V) oi=jai,j(xjWK+Ri,jV)

XLNET式

起源于Transformer-XL的论文《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》。将绝对位置编码中的 q i k j T q_ik_j^T qikjT完全展开可得到:
x i W Q W K T x j T + x i W Q W K T p j T + p i W Q W K T x j T + p i W Q W K T p j T x_iW_QW_K^Tx_j^T+x_iW_QW_K^Tp_j^T+p_iW_QW_K^Tx_j^T+p_iW_QW_K^Tp_j^T xiWQWKTxjT+xiWQWKTpjT+piWQWKTxjT+piWQWKTpjT
于是transformer-XL将 p j p_j pj替换为相对位置向量 R i − j T R_{i-j}^T RijT,两个 p i p_i pi替换成可训练的向量 u , v u,v u,v:
x i W Q W K T x j T + x i W Q W K T R i − j T + u W Q W K T x j T + v W Q W K T R i − j T x_iW_QW_K^Tx_j^T+x_iW_QW_K^TR_{i-j}^T+uW_QW_K^Tx_j^T+vW_QW_K^TR_{i-j}^T xiWQWKTxjT+xiWQWKTRijT+uWQWKTxjT+vWQWKTRijT

并且直接去掉了 v j v_j vj的位置偏置,直接令 o i = ∑ j a i , j x j W V o_i=\sum_ja_{i,j}x_jW_V oi=jai,jxjWV

T5式

来源于文章《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》,该篇文章认为输入纤细与位置信息应该是独立(解耦)的,那么它们就不应该有过多的交互,所以“输入-位置”,“位置-输入”两项的Attention可以删掉,而 p i W Q W K T p j T p_iW_QW_K^Tp_j^T piWQWKTpjT只是一个依赖于 ( i , j ) (i,j) (i,j)的标量,可以直接将它作为参数训练出来,即简化为:
x i W Q W K T x J t + β i , j x_iW_QW_K^Tx_J^t+\beta _{i,j} xiWQWKTxJt+βi,j,与XL一样在 v j v_j vj上的位置偏置直接被去掉。
但与一般不同的是,不同于常规位置编码对将 β i , j \beta _{i,j} βi,j视为 i − j i-j ij的函数进行截断的做法,T5对相对位置进行了“分桶”处理,即相对位置是 i − j i-j ij的实际上对应的是 f ( i − j ) f(i-j) f(ij)的位置,映射关系如下:
在这里插入图片描述

DeBERTa式

起源于《DeBERTa: Decoding-enhanced BERT with Disentangled Attention》,T5是剔除掉第2和第3项,而DeBERTa则去掉第四项:
q i k j T = x i W Q W K T x j T + x i W Q W K T R i , j T + R j , i W Q W K T x j T q_ik_j^T=x_iW_QW_K^Tx_j^T+x_iW_QW_K^TR_{i,j}^T+R_{j,i}W_QW_K^Tx_j^T qikjT=xiWQWKTxjT+xiWQWKTRi,jT+Rj,iWQWKTxjT

FLAT: Chinese NER Using Flat-Lattice Transformer(ACL2020)

FLAT也采用相对位置编码,attention矩阵的计算方式如下:
A i , j ∗ = E x i W q W k T E x j T + E x i W q W k , R T R i , j T + u W k , E T E x j T + v W k , R T R i , j T A_{i,j}^{*}=E_{x_i}W_qW_{k}^{T}E_{x_j}^{T} + E_{x_i}W_{q}W_{k,R}^{T}R_{i,j}^{T} + uW_{k,E}^{T}E_{x_j}^{T} + vW_{k,R}^{T}R_{i,j}^{T} Ai,j=ExiWqWkTExjT+ExiWqWk,RTRi,jT+uWk,ETExjT+vWk,RTRi,jT
并提出了四中相对距离的表示方式,同时考虑字符和词之间的关系:
{ d i j h h = h e a d [ i ] − h e a d [ j ] d i j h t = h e a d [ i ] − t a i l [ j ] d i j t h = t a i l [ i ] − h e a d [ j ] d i j t t = t a i l [ i ] − t a i l [ j ] \left\{ \begin{aligned} d_{ij}^{hh} & = & head[i]-head[j] \\ d_{ij}^{ht} & = & head[i]-tail[j]\\ d_{ij}^{th} & = & tail[i]-head[j]\\ d_{ij}^{tt} & = & tail[i]-tail[j] \end{aligned} \right. dijhhdijhtdijthdijtt====head[i]head[j]head[i]tail[j]tail[i]head[j]tail[i]tail[j]
相对位置的encoding为:
在这里插入图片描述
的计算方式与vanilla Transformer相同。

### Transformer 中相对位置编码的改进方法 #### 背景介绍 Transformer 模型最初采用的是绝对位置编码 (Absolute Positional Encoding),即通过固定的位置索引来表示输入序列中的每个词。然而,在处理长序列时,绝对位置编码可能无法充分捕捉远距离上下文关系。因此,研究者提出了多种 **相对位置编码** 方法来解决这一问题。 #### 相对位置编码的基本概念 相对位置编码的核心思想在于利用序列中元素间的相对偏移量而非固定的绝对位置来进行建模[^1]。这种方法能够更好地捕获长程依赖关系,并减少因序列长度增加而导致的信息丢失风险。 #### 改进方法概述 以下是几种常见的相对位置编码改进方法及其对应的实现方式: --- #### 1. **基于注意力机制的相对位置编码** 该方法通过对自注意力机制进行扩展,将相对位置信息融入到注意力权重计算过程中。具体而言,对于任意两个 token \(i\) 和 \(j\),定义其相对位置为 \(\text{rel}_{ij} = j - i\)。然后将其映射至一个可学习的嵌入矩阵中。 ##### 实现细节 假设我们有一个大小为 \(d_{\text{model}}\) 的隐藏层维度,则可以构建如下形式的相对位置编码: \[ E_{\text{rel}}(k) = f(k), \quad k \in [-K, K], \] 其中 \(f(k)\) 表示某种函数(如正弦/余弦波形),\(K\) 是最大允许的距离范围。 在实际应用中,可以通过以下代码片段展示如何生成这些相对位置向量: ```python import torch import math def generate_relative_positions(max_len=512, d_model=512): positions = torch.arange(-max_len, max_len).float() / max_len * d_model embeddings = [] for pos in positions: embedding = [math.sin(pos / (10000**(2*i/d_model))) if i % 2 == 0 else \ math.cos(pos / (10000**(2*(i-1)/d_model))) for i in range(d_model)] embeddings.append(embedding) return torch.tensor(embeddings) relative_pos_embeddings = generate_relative_positions() print(relative_pos_embeddings.shape) # 输出形状应为 (2*max_len+1, d_model) ``` 此部分涉及的内容主要来源于早期关于 Transformer-XL 的工作。 --- #### 2. **旋转位置编码 (Rotary Position Embedding)** 另一种有效的改进方案称为“旋转位置编码”,它通过周期性的三角函数变换实现了动态调整的能力[^2]。相比于静态的绝对或相对位置编码,这种方式具有更强的表现力以及更少的记忆负担。 ##### 数学描述 给定一对 query (\(q_i\)) 及 key (\(k_j\)) 向量,我们可以分别对其进行分解并施加特定角度的旋转变换操作: \[ q'_i[k] = \begin{cases} q_i[k]\cos{\theta_k}, & \text{(even)} \\ -q_i[k]\sin{\theta_k}, & \text{(odd)} \end{cases}, \] 同理适用于 keys 的转换过程。这里 \(\theta_k=\frac{1}{T^{2k/d}}\) 定义了一个随频率变化的角度参数表。 上述逻辑可以用 PyTorch 编写成如下模块化结构: ```python class RotaryPositionEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() self.dim = dim inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) @staticmethod def rotate_half(x): x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:] return torch.cat((-x2, x1), dim=-1) def forward(self, seq_len): t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) # shape: [seq_len, dim] cos_emb, sin_emb = emb.cos(), emb.sin() return cos_emb[:, None, :], sin_emb[:, None, :] # Add batch dimension rot_pe = RotaryPositionEmbedding(dim=64) cosines, sines = rot_pe(seq_len=8) print(cosines.shape, sines.shape) # 应输出 [(8, 1, 64)] twice. ``` 这部分技术通常被用于视觉领域内的 ViTs 或 NLP 领域的大规模预训练模型当中。 --- #### 3. **乘积法 (Product Method)** 针对传统交叉法存在的不足之处——比如容易混淆不同方向上的相似间隔情况等问题,有学者提出了一种名为 “乘积法” 的新型策略[^4]。它的核心优势体现在两方面:一是显著降低了存储需求;二是增强了区分度较高的多维特征表达能力。 简单来说,就是把原本独立的一维坐标替换成了由多个因子共同决定的新体系。例如在一个二维平面上,假设有水平轴 L 和垂直轴 V ,那么对应某个点 P(i,j) 就能写作: \[ e_P[i][j]=e_L[i]*e_V[j]. \] 下面是一个简化版的例子演示如何创建这样的表格数据集: ```python def product_method_encoding(height=16, width=16, embed_dim=32): assert embed_dim % 2 == 0, 'Embedding size must be even.' half_d = embed_dim // 2 row_embs = torch.zeros([height, half_d]) col_embs = torch.zeros([width, half_d]) for h in range(height): row_embs[h] = [ math.sin(h/(10000**(r/half_d))) if r%2==0 else math.cos(h/(10000**((r-1)/half_d))) for r in range(half_d)] for w in range(width): col_embs[w] = [ math.sin(w/(10000**(c/half_d))) if c%2==0 else math.cos(w/(10000**((c-1)/half_d))) for c in range(half_d)] combined = torch.matmul(row_embs.unsqueeze(2), col_embs.T.unsqueeze(0)).view(height*width,-1) return combined encoded_matrix = product_method_encoding() print(encoded_matrix.size()) # Expected output: torch.Size([256, 32]) when height=width=16 and embed_dim=32 ``` 以上算法特别适合应用于图像分类或者目标检测任务之中。 --- ### 总结 综上所述,目前主流的三种相对位置编码改进方法分别为基于注意力机制的方法、旋转位置编码以及乘积法。每种都有各自的特点和适用场景,开发者可以根据具体的项目需求灵活选用合适的解决方案。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值