@@ -6,16 +6,14 @@
 class FFN(nn.Module):
     def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        # lin1
-        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
-        # lin2
-        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
+        self.linear_in = Linear(n_embed, dim, r=r, bias=True)
+        self.linear_out = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu
 
     def forward(self, hidden_states):
-        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.linear_out(hidden_states)
         return hidden_states
 
 
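Note that `Linear` here is not `torch.nn.Linear`: the extra `r` argument suggests a LoRA-style low-rank-adapted linear layer defined elsewhere in the repo. The diff does not show that class, so the following is only a minimal sketch assuming standard LoRA semantics (frozen base weight plus a trainable rank-`r` update scaled by `alpha / r`):

```python
import math
import torch
import torch.nn as nn

class Linear(nn.Module):
    """Sketch of a LoRA-style linear layer (assumed, not the repo's actual class)."""
    def __init__(self, in_features: int, out_features: int, r: int,
                 bias: bool = True, alpha: int = 1):
        super().__init__()
        # Frozen pretrained weight; in practice loaded from a checkpoint.
        self.weight = nn.Parameter(torch.empty(out_features, in_features),
                                   requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Trainable rank-r factors; B starts at zero so the adapter is a no-op at init.
        self.lora_A = nn.Parameter(torch.randn(r, in_features) / math.sqrt(r))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        base = nn.functional.linear(x, self.weight, self.bias)
        # Low-rank update: x @ A.T @ B.T, scaled.
        update = nn.functional.linear(nn.functional.linear(x, self.lora_A), self.lora_B)
        return base + update * self.scaling
```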
@@ -27,10 +25,10 @@ def __init__(self, n_embed: int, r: int):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
-        # qkv
-        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        # query key value
+        self.qkv_projection = Linear(n_embed, n_embed * 3, r=r, bias=True)
         # out
-        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
+        self.output_projection = Linear(n_embed, n_embed, r=r, bias=True)
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
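The body of `_split_heads` falls outside these hunks. In GPT-2-style attention it conventionally reshapes `(batch, seq, embed)` into `(batch, heads, seq, head_dim)`; a sketch under that assumption (not taken from this diff):

```python
def _split_heads(self, tensor, num_heads, attn_head_size):
    # (batch, seq, embed) -> (batch, seq, heads, head_dim)
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(new_shape)
    # -> (batch, heads, seq, head_dim) so attention runs per head
    return tensor.permute(0, 2, 1, 3)
```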
@@ -43,7 +41,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
     def forward(self, hidden_states):
         batch_size, seq_length, _ = hidden_states.size()
 
-        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)
+        query, key, value = self.qkv_projection(hidden_states).split(self.split_size, dim=2)
 
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
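Since `split_size == embed_dim` and the fused projection emits `3 * n_embed` features, the `split` along `dim=2` yields three `(batch, seq, n_embed)` tensors. A quick shape check with hypothetical sizes:

```python
import torch

batch, seq, n_embed = 2, 5, 8
fused = torch.randn(batch, seq, 3 * n_embed)   # stand-in for qkv_projection output
q, k, v = fused.split(n_embed, dim=2)          # split_size == embed_dim == n_embed
print(q.shape, k.shape, v.shape)               # each: torch.Size([2, 5, 8])
```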
@@ -61,7 +59,7 @@ def forward(self, hidden_states):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
 
-        attn_output = self.c_proj(attn_output)
+        attn_output = self.output_projection(attn_output)
 
         return attn_output
 
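For reference, a hedged usage sketch of the two renamed modules. The attention class name `Attention` and the `4 * n_embed` FFN width are assumptions (neither appears in the diff); both constructors follow the signatures shown above:

```python
import torch

n_embed, r = 8, 2
x = torch.randn(2, 5, n_embed)            # (batch, seq, embed)

ffn = FFN(dim=4 * n_embed, n_embed=n_embed, r=r)   # 4x expansion assumed
attn = Attention(n_embed=n_embed, r=r)             # class name assumed

assert ffn(x).shape == x.shape            # FFN maps back to n_embed
assert attn(x).shape == x.shape           # attention preserves (batch, seq, embed)
```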