import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 small (124M) hyperparameters
config = {
    "layer_norm_epsilon": 1e-05,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "vocab_size": 50257,
    "device": "cuda"
}

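# The `Linear` and `Embedding` layers imported above are LoRA modules: a frozen
# base weight plus a trainable low-rank update of rank `r`. The class below is a
# minimal sketch of the standard LoRA parametrization, h = W0 x + (alpha / r) * B A x,
# for exposition only; its names (`lora_a`, `lora_b`, `alpha`) are assumptions and
# it is not the labml_nn implementation.
class _LoRALinearSketch(nn.Module):
    def __init__(self, in_features, out_features, r=32, alpha=32):
        super().__init__()
        # Frozen pretrained weight W0
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # Trainable low-rank factors: A is Gaussian-initialised, B starts at zero,
        # so the update B A is zero at the beginning of fine-tuning
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        base = nn.functional.linear(x, self.weight)
        update = nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
        return base + self.scaling * update
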
class FFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Expansion and contraction projections, both LoRA-adapted with rank 32
        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        # GELU activation (the Hugging Face GPT-2 reference uses the tanh approximation, `gelu_new`)
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states

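# Usage sketch (not executed here): the feed-forward block maps
# (batch, seq, n_embd) -> (batch, seq, 4 * n_embd) -> (batch, seq, n_embd),
# so it preserves the width of the residual stream.
#
#   ffn = FFN(config['n_embd'] * 4)
#   x = torch.randn(2, 16, config['n_embd'])
#   assert ffn(x).shape == x.shape
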
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # Fused query/key/value projection and output projection, both LoRA-adapted with rank 32
        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits the last (hidden) dimension into (num_heads, attn_head_size)
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        # Project once, then split into query, key and value along the last dimension
        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # applies the causal (lower-triangular) mask
        )

        # Merge heads back: (batch, head, seq, head_dim) -> (batch, seq, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output

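# Illustrative reference only (not used by the model): `is_causal=True` above is
# equivalent to passing an explicit lower-triangular boolean mask, so position t
# can attend only to positions <= t. The helper name below is an assumption.
def _causal_attention_reference(query, key, value):
    seq_length = query.size(-2)
    causal_mask = torch.ones(seq_length, seq_length, dtype=torch.bool, device=query.device).tril()
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, dropout_p=0.0
    )
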
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHeadAttention()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = FFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        # Pre-norm residual attention sub-layer
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)
        attn_output = self.attn(hidden_states)
        hidden_states = attn_output + residual

        # Pre-norm residual feed-forward sub-layer
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states

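# Rough parameter budget per block, assuming each LoRA layer adds
# r * (d_in + d_out) adapter parameters on top of its frozen weight (r = 32):
#   c_att  (768 -> 2304): 32 * (768 + 2304) =  98,304
#   c_proj (768 ->  768): 32 * (768 +  768) =  49,152
#   c_fc   (768 -> 3072): 32 * (768 + 3072) = 122,880
#   c_proj (3072 -> 768): 32 * (3072 + 768) = 122,880
# That is ~393k trainable adapter parameters against ~7.1M frozen projection
# weights per block, i.e. roughly 5-6% of the block's weights.
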
class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (B, T, C)
        position_ids = torch.arange(seq_length, device=input_ids.device)  # (T,)
        position_embeddings = self.position_embedding(position_ids)  # (T, C), broadcast over the batch

        hidden_states = token_embeddings + position_embeddings

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits
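
# Illustrative smoke test (an assumption about intended usage, not part of the
# model code): run a prompt through the model without loading pretrained GPT-2
# weights, just to exercise the forward pass and check the logits shape.
if __name__ == "__main__":
    device = config['device'] if torch.cuda.is_available() else "cpu"
    model = GPTModel().to(device)

    prompt = "LoRA fine-tunes large language models"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        logits = model(input_ids)  # (batch, seq, vocab_size)

    assert logits.shape == (input_ids.shape[0], input_ids.shape[1], config['vocab_size'])
    next_token_id = logits[0, -1].argmax().item()
    print("greedy next token:", tokenizer.decode([next_token_id]))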