2023/12/07 11:34

# 一、 LLaMA 的模型结构

## 1.1. RMSNorm 归一化函数

RMSNorm 在HuggingFace Transformer 库中代码实现如下所示：

class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps # eps 防止取倒数之后分母为0
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# weight 是末尾乘的可训练参数, 即g_i
return (self.weight * hidden_states).to(input_dtype)

## 1.2. SwiGLU 激活函数

SwiGLU[50] 激活函数是Shazeer 在文献中提出，并在PaLM等模中进行了广泛应用，并且取得了不错的效果，相较于ReLU 函数在大部分评测中都有不少提升。在LLaMA 中全连接层使用带有SwiGLU 激活函数的FFN（Position-wise Feed-Forward Network）的计算公式如下：

## 1.3. 旋转位置嵌入（RoPE）

class LlamaRotaryEmbedding(torch.nn.Module):

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make torch.jit.trace work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# This if block is unlikely to be run after we build sin/cos in __init__.
# Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
persistent=False)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can squeeze them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

## 1.4. 模型整体框架

HuggingFace Transformer 库中LLaMA 解码器整体实现代码实现如下所示：

class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

# 二、注意力机制优化

## 2.1. 稀疏注意力机制

（1）全局注意力（Global Attention）：为了增强模型建模长距离依赖关系，可以加入一些全局节点；

（2）带状注意力（Band Attention）：大部分数据都带有局部性，限制Query 只与相邻的几个节点进行交互；

（3）膨胀注意力（Dilated Attention）；与CNN 中的Dilated Conv 类似，通过增加空隙以获取更大的感受野；

（4）随机注意力（Random Attention）：通过随机采样，提升非局部的交互；

（5）局部块注意力（Block Local Attention）：使用多个不重叠的块（Block）来限制信息交互。

Star-Transformer[54] 使用带状注意力和全局注意力的组合。具体来说，Star-Transformer 只包括一个全局注意力节点和宽度为3 的带状注意力，其中任意两个非相邻节点通过一个共享的全局注意力连接，而相邻节点则直接相连。

Longformer使用带状注意力和内部全局节点注意力（Internal Global-node Attention）的组合。此外，Longformer 还将上层中的一些带状注意力头部替换为具有扩张窗口的注意力，在增加感受野同时并不增加计算量。Extended Transformer Construction（ETC）利用带状注意力和外部全局节点注意力（External Global-node Attention）的组合。ETC 稀疏注意力还包括一种掩码机制来处理结构化输入，并采用对比预测编码（Contrastive Predictive Coding，CPC）进行预训练。

BigBird使用带状和全局注意力，还使用额外的随机注意力来近似全连接注意力，此外还揭示了稀疏编码器和稀疏解码器的使用可以模拟任何图灵机，这也在一定程度上解释了，为什么稀疏注意力模型可以取得较好的结果原因。

## 2.2. FlashAttention

NVIDIA GPU 中的内存（显存）按照它们物理上是在GPU 芯片内部还是板卡RAM 存储芯片上，决定了它们的速度、大小以及访问限制。GPU 显存分为全局内存（Global memory）、本地内存（Local memory）、共享内存（Shared memory，SRAM）、寄存器内存（Register memory）、常量内存（Constant memory）、纹理内存（Texture memory）等六大类。图2.8给出了NVIDIA GPU 内存的整体结构。其中全局内存、本地内存、共享内存和寄存器内存具有读写能力。

NVIDIA H100 中每个GPU 线程块在流式多处理器（Stream Multi-processor，SM）可以使用的共享存储容量仅有228KB，但是其速度非常快，远高于全局内存的访问速度。

FlashAttention就是通过利用GPU 硬件中的特殊设计，针对全局内存和共享存储的I/O 速度的不同，尽可能的避免HBM 中读取或写入注意力矩阵。

FlashAttention 目标是尽可能高效地使用SRAM 来加快计算速度，避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax 函数，并且后向传播中不能存储中间注意力矩阵。

FlashAttention 算法并没有将S、P 整体写入全局内存，而是通过分块写入，存储前向传递的Softmax 归一化因子，在后向传播中快速重新计算片上注意力，这比从全局内容中读取中间注意力矩阵的标准方法更快。由于大幅度减少了全局内存的访问量，即使重新计算导致FLOPs 增加，但其运行速度更快并且使用更少的内存。具体算法如代码2.2所示，其中内循环和外循环所对应的计算可以参考下图。

2.3 FlashAttention 计算流程图

2.3. 多查询注意力

class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
"""
def __init__(
self,
d_model: int,
device: Optional[str] = None,
):
super().__init__()
self.d_model = d_model
self.Wqkv = nn.Linear( # Multi-Query Attention 创建
d_model,
d_model + 2 * self.head_dim, # 只创建查询的头向量，所以只有1 个d_model
)
self.out_proj = nn.Linear(
self.d_model,
self.d_model,
device=device
)
self.out_proj._is_residual = True # type: ignore
def forward(
self,
x,
):
qkv = self.Wqkv(x) # (1, 512, 960)
query, key, value = qkv.split( # query -> (1, 512, 768)
dim=2 # value -> (1, 512, 96)
)
context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
multiquery=True,
)
return self.out_proj(context), attn_weights, past_key_value

# Multi Head Attention
self.Wqkv = nn.Linear( # Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # 查询、键和值3 个矩阵, 所以是3 * d_model
device=device
)
query, key, value = qkv.chunk( # 每个tensor 都是(1, 512, 768)
3,
dim=2
)
# Multi Query Attention
self.Wqkv = nn.Linear( # Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建查询的头向量，所以是1* d_model
device=device, # 而键和值不再具备单独的头向量
)
query, key, value = qkv.split( # query -> (1, 512, 768)
dim=2 # value -> (1, 512, 96)
)

0 评论
0 收藏
0