Llama 3 是 Meta 开源的大模型,在开源大模型中占有重要地位;在它之前较有代表性的开源模型可能是 Mistral,目前同类的还有 Gemma 2、Qwen2 以及微软的 Phi-3 等。
Llama 的表现很不错(可参考 LLM Leaderboard (toloka.ai)),很多模型都是在它的基础上微调得到的。
下面从位置编码、transformer 层(注意力)、FFN 层以及其中 norm 的改进几个方面介绍 Llama。
Llama 3 相比于 Llama 2:上下文窗口从 4K 增大到 8K;tokenizer 从 SentencePiece 换成了 tiktoken,词表(token 数)也从 32K 增大到约 128K。
Feature | Llama 2 | Llama 3 |
---|---|---|
Training Data Size | 2 trillion tokens | 15 trillion tokens (7x larger) |
Context Window | 4K tokens | 8K tokens |
Focus Area | General language understanding | Nuance, context, complex tasks |
False Refusal Rate | Higher | Lower |
Response Diversity | Lower | Higher |
Code Generation | Limited capability | Enhanced capability |
旋转位置编码
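旋转位置编码(Rotary Position Embedding, RoPE)是一种用于 Transformer 模型的位置编码技术,用来把位置信息融入 token 表示中。与原始 Transformer 中把正弦/余弦位置编码直接加到词嵌入上的做法不同,RoPE 在计算注意力之前用旋转矩阵对 query 和 key 做旋转:把第 $i$ 对特征 $(x_{2i}, x_{2i+1})$ 在位置 $m$ 处旋转角度 $m\theta_i$,其中 $\theta_i = \text{base}^{-2i/d}$(base 默认 10000,Llama 3 中由 `rope_theta` 配置,取 500000)。旋转角由绝对位置决定,而 $\langle R_m q, R_n k\rangle = \langle q, R_{n-m} k\rangle$ 只依赖相对位置 $n-m$,因此 RoPE 同时编码了绝对和相对位置信息。这种方法的提出是为了提高位置编码在 Transformer 中的有效性。Meta 的 LLaMA、清华的 ChatGLM 都采用了 RoPE。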
在 Llama 3 的参考实现中,代码如下:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotation factors used by rotary position embeddings.

    Args:
        dim: Per-head feature size; one frequency is produced per feature pair,
            i.e. ``dim // 2`` frequencies.
        end: Number of positions to precompute (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``[end, dim // 2]`` whose entry ``(m, i)`` is
        ``exp(1j * m * theta ** (-2i / dim))``: a unit-modulus rotation by angle
        ``m * theta_i``.
    """
    # One rotation frequency per feature pair: theta_i = theta^(-2i/dim), i = 0 .. dim//2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  # positions m
    freqs = torch.outer(t, freqs)  # [end, dim // 2], entry (m, i) = m * theta_i
    # polar(1, angle) = cos(angle) + 1j*sin(angle); complex64 because the inputs are float32.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
# Rotation factors for every position the model may see.
# NOTE(review): the table covers 2 * max_seq_len positions; the reason for the
# factor of 2 is not visible here (presumably headroom) — confirm.
self.freqs_cis = precompute_freqs_cis(
    dim=params.dim // params.n_heads,  # per-head feature size
    end=params.max_seq_len * 2,
    theta=params.rope_theta,
)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The view keeps
    those sizes at axis 1 and at the last axis and uses size 1 on every other
    axis, so it multiplies elementwise with ``x``.
    """
    rank = x.ndim
    assert rank >= 2  # x needs an axis 1 to align with
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last_axis = rank - 1
    target_shape = []
    for axis, size in enumerate(x.shape):
        if axis == 1 or axis == last_axis:
            target_shape.append(size)
        else:
            target_shape.append(1)
    return freqs_cis.view(*target_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by precomputed complex rotation factors.

    Adjacent feature pairs ``(2i, 2i + 1)`` of the last dimension are viewed as
    one complex number and multiplied by ``freqs_cis``, which rotates each pair
    by its position-dependent angle.

    Args:
        xq: Queries, e.g. ``(bsz, seqlen, n_heads, head_dim)`` as produced by the
            attention layer; axis 1 is matched against ``freqs_cis``. The last
            dimension must be even.
        xk: Keys, same layout as ``xq``; the head axis may be smaller (``freqs_cis``
            broadcasts over it).
        freqs_cis: Complex tensor of shape ``(xq.shape[1], xq.shape[-1] // 2)``
            (checked by ``reshape_for_broadcast``).

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # [..., head_dim] -> [..., head_dim // 2, 2] -> complex [..., head_dim // 2]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # (seqlen, head_dim // 2) -> (1, seqlen, 1, ..., head_dim // 2)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiply = rotation. view_as_real yields [..., head_dim // 2, 2];
    # flatten(-2) merges the last two axes back into head_dim for inputs of any
    # rank (flatten(3) only did so for 4-D inputs).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
# Rotate queries and keys in place of the originals (values are not rotated).
xq, xk = apply_rotary_emb(xq=xq, xk=xk, freqs_cis=freqs_cis)
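`precompute_freqs_cis` 预先计算 `freqs_cis`:它是形状为 `[end, dim//2]` 的复数张量,第 $(m, i)$ 个元素为 $e^{\mathrm{i}\, m\theta_i}$。`apply_rotary_emb` 把 q、k 最后一维中相邻的两个数看作一个复数,与 `freqs_cis` 相乘即完成旋转。旋转编码只作用在 q 和 k 上,v 不做旋转。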
下面是另一种实现(rotate-half 写法:把前一半特征与后一半特征配对旋转,而不是相邻两维配对;输入布局为 `[seq_len, batch, n_heads, d]`):
import torch
import torch.nn as nn


class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embeddings (RoPE), "rotate-half" formulation.

    ``x`` is laid out as ``[seq_len, batch, n_heads, d]``: positions are taken
    from axis 0, and the cos/sin tables are cached with shape ``[seq_len, 1, 1, d]``
    so they broadcast over the batch and head axes.

    Feature ``j`` is rotated together with feature ``j + d/2`` (first half vs.
    second half). Implementations that pair adjacent features ``(2i, 2i + 1)``
    give different outputs for the same input unless the feature order is
    permuted accordingly.
    """

    def __init__(self, d: int, base: int = 10_000):
        """
        Args:
            d: Number of rotated features; must be even and equal to ``x.shape[-1]``.
            base: Base of the frequency progression ``theta_i = base^(-2i/d)``.
        """
        super().__init__()
        self.base = base
        self.d = d
        # Built lazily by _build_cache; shape [cached_len, 1, 1, d].
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """(Re)build the cos/sin tables when missing, too short, or on another device than ``x``."""
        seq_len = x.shape[0]
        cached = self.cos_cached
        # Reuse the tables only if they cover seq_len positions AND live on x's
        # device; otherwise forward() would mix devices after x is moved.
        if (
            cached is not None
            and seq_len <= cached.shape[0]
            and cached.device == x.device
        ):
            return
        # theta_i = base^(-2i/d), i = 0 .. d/2 - 1
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Position indices m = 0 .. seq_len - 1
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # [seq_len, d/2]: entry (m, i) = m * theta_i
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Duplicate so features j and j + d/2 share the same angle: [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        """Return ``[-x2, x1]``, where ``x1``/``x2`` are the first/second halves of the last axis."""
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """Rotate ``x`` (``[seq_len, batch, n_heads, d]``) by position; output has the same shape."""
        self._build_cache(x)
        seq_len = x.shape[0]
        # x * cos(m * theta) + rotate_half(x) * sin(m * theta)
        return x * self.cos_cached[:seq_len] + self._neg_half(x) * self.sin_cached[:seq_len]
if __name__ == '__main__':
    # Demo input: 3 positions, batch 1, 1 head, d = 4 -> shape [3, 1, 1, 4].
    rows = torch.arange(4, dtype=torch.float) + torch.tensor([1.0, 4.0, 7.0])[:, None]
    demo_x = rows.unsqueeze(1).unsqueeze(1)
    rope = RotaryPositionalEmbeddings(4)
    print(rope(demo_x))
RMSNorm
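RMSNorm 是另一种归一化方式,由 2019 年的论文 *Root Mean Square Layer Normalization*(arXiv:1910.07467)提出。与 LayerNorm 相比,它不减去均值、也没有偏置,只按均方根缩放并乘以可学习的增益 $g$:$\bar{x}_i = \dfrac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}\, g_i$。代码如下: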
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension (Zhang & Sennrich, 2019).

    Each vector is divided by its RMS (no mean subtraction, no bias) and then
    scaled by a learned per-feature gain initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # added to the mean square before the reciprocal sqrt
        self.weight = torch.nn.Parameter(torch.ones(dim))  # per-feature gain

    def _norm(self, x):
        """Scale ``x`` by ``1 / sqrt(mean(x**2, last axis) + eps)``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
在 Llama 3 中,RMSNorm 用在三个地方:attention 之前、FFN 之前,以及所有 transformer layer 之后。前两处(pre-norm 残差结构)如下:
# Pre-norm residual blocks: normalise, run the sub-layer, add the result back.
attn_out = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
h = x + attn_out
ffn_out = self.feed_forward(self.ffn_norm(h))
out = h + ffn_out
Grouped-Query Attention 与 KV cache
Grouped-Query Attention (GQA):query 的 head 数是 key/value head 数的整数倍,每组 query head 共享同一个 key/value head,从而减少 K/V 投影的参数量和 KV cache 的大小。
KV cache:自回归生成时,由于注意力是因果的,已生成 token 的 K 和 V 在后续步骤中不会改变,因此可以把它们缓存起来;每生成一个新 token,只需计算该 token 的 K、V 并写入缓存,再与缓存中全部的 K、V 做注意力。这是用显存换计算的做法:节省的是重复计算,缓存本身反而会占用额外显存,而 GQA 减少了 KV head 数,正好能降低这部分显存占用。代码如下:
class Attention(nn.Module):
    """Self-attention with grouped-query K/V heads, rotary embeddings and a KV cache.

    Heads are divided evenly across model-parallel ranks. When there are fewer
    key/value heads than query heads, each local K/V head is repeated ``n_rep``
    times (via ``repeat_kv``) so every query head has a matching key/value head.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Without an explicit n_kv_heads this is ordinary multi-head attention.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        # Heads owned by this model-parallel rank.
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Number of query heads that share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # init_method is the identity, so nothing is initialised here —
        # presumably weights are loaded from a checkpoint (confirm).
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # KV cache: (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
        # Allocate on CUDA when available, otherwise on CPU (a bare .cuda() raises
        # on CPU-only machines); forward() moves the cache to q's device/dtype anyway.
        cache_shape = (
            args.max_batch_size,
            args.max_seq_len,
            self.n_local_kv_heads,
            self.head_dim,
        )
        cache_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_k = torch.zeros(cache_shape, device=cache_device)
        self.cache_v = torch.zeros(cache_shape, device=cache_device)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Attend from the ``seqlen`` new tokens in ``x`` to every cached token.

        Args:
            x: Input of shape ``(bsz, seqlen, model_dim)``.
            start_pos: Cache index where this chunk's keys/values are written;
                attention covers cache positions ``[0, start_pos + seqlen)``.
            freqs_cis: Rotary factors for ``apply_rotary_emb``; must be shaped
                ``(seqlen, head_dim // 2)`` — presumably the slice for positions
                ``start_pos .. start_pos + seqlen - 1`` (confirm at the caller).
            mask: Optional additive mask for scores shaped
                ``(bsz, n_local_heads, seqlen, start_pos + seqlen)``; ``None`` skips it.

        Returns:
            ``self.wo`` applied to the concatenated per-head outputs.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Rotary position embedding is applied to queries and keys only.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Keep the cache on the same device and dtype as the activations.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Store this chunk at positions [start_pos, start_pos + seqlen) ...
        self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

        # ... and attend over everything cached so far (cache_len == start_pos below).
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # Scaled dot-product scores.
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # Softmax is computed in float32, then cast back to the activation dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        # Merge heads back into one feature axis: (bs, seqlen, n_local_heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
Transformer
1 | class Transformer(nn.Module): |
silu激活函数
1 | class FeedForward(nn.Module): |