Three Month
文章 15
標籤 11
分類 4
為什麼torch.nn.transformer中每個input的feature size需要是head數量的倍數

為什麼torch.nn.transformer中每個input的feature size需要是head數量的倍數

約1.2k字 大概需要5分鐘

前言

在第一次新生訓練時,學長有講解有關transformer的架構,而講解完後,他留了一個小作業,也就是關於 multi-head transformer在pytorch中的實作中,d_model需要被nhead整除,d_model是指說input的feature size,nhead是指說head的數量。這個結果好像跟理論上有些落差,並要我們解釋原因。

Multi-Head Transformer 理論

在學長的講解過程中,有提到對於Multi-Head Transformer是對同一個input做了很多個Self-Attention後,再將多次輸出連接在一起後,再和一個大矩陣相乘,壓回原本的大小,這個過程也就是做加權平均。具體過程如下兩張圖。

Multi-Head Transformer的輸出

將多個輸出壓回一個輸出的過程

而上述的過程中,我們會發現無論輸入x的size是多少,都與head的數量無關,也就是說不論feature size和head數量是多少,理論上都能訓練得起來。

為什麼torch.nn.transformer中d_model需要被nhead整除

上個段落講說其實feature size和head數量無關,但如果對nn.transformer填入任意的feature size和head數量,可能會被提示說feature size需為head數量的倍數。

nn.transformer的報錯範例

至於為什麼,現在慢慢trace source code。
首先先看到nn.transformer: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer
其中與nhead和d_model相關的部分則是TransformerEncoderLayer這個class。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Transformer(Module):
#......
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first,
**factory_kwargs)
encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
### ......

所以繼續看class TransformerEncoderLayer到底寫了什麼,我們發現在class TransformerEncoderLayer中與d_model和nhead相關的class為MultiheadAttention。

1
2
3
4
5
6
7
8
9
10
11
12
class TransformerEncoderLayer(Module):
#......
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
**factory_kwargs)
#......
#......

為了繼續trace code,我又在網路上找了class MultiheadAttention: https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention
仔細看,他又把forward的過程丟給了F.multi_head_attention_forward。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class MultiheadAttention(Module):
#.......
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
#.......
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
is_causal=is_causal)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
#......

所以我們再去看F.multi_head_attention_forward寫了什麼: https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
#......
)
#......
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
#......
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None

首先會看到他寫了nhead必須要被embed_dim整除,然後整除後的變數叫做head_dim。

然後順著有head_dim變數的地方找,發現在最後一段中,qkv都變成了4維陣列,再丟進scaled_dot_product_attention運算。
可以仔細觀察四維陣列形狀,分別都是由(batch size, number of head, target/source length, head dimension)組成。如果根據上一段的理論來說,格式應該是(batch size, target/source length, feature size = number of head * head dimension)。在這不難看出qkv變成了4維陣列的理由,就是單純對target/source length, head dimension做矩陣乘法運算,而number of head那個dimension就像是batch size那個dimension一樣,完全不受影響。

簡潔來說torch.nn.transformer的multi-head實作,透過將feature size平均拆成多份,並將每份分給不同的head做運算,最後再接起來。這種作法的好處就是可以在稍微影響效能的情形下,省去大量的運算,因為每個input平均分給不同的head計算,不會有重複餵input再加權平均的情況發生。假如multi-head是8的話,torch.nn.transformer在這部分減少了約7/8的計算量。

心得

在這次trace code的過程中,除了發現理論和實作上的差距,我還發現官方的code,oop寫的比我在交大修的任何AI課還要嚴謹許多,完全不是單純把model刻出來就結束了。在封裝上也很漂亮,也寫了許多assert來規範奇怪的輸入。透過trace code的過程,也能學習到良好的coding習慣。

本文作者:Three Month
本文連結:https://threemonth03.github.io/2023/07/14/2023-07-14-torch.nn.transformer%E7%9A%84%E7%B4%B0%E7%AF%80/
版權聲明:本文使用 CC BY-NC-SA 3.0 CN 協議進行許可
×