Grokking Modular Arithmetic (MLX)

![Grokking](media/grokking.png)

Grokking is a phenomenon observed in neural networks where, after an initial phase of overfitting (or memorization), the model suddenly achieves perfect generalization; see Power et al. (2022). We incorporate some modern Transformer tricks (e.g., RoPE, RMSNorm, SiLU) and achieve grokking in < 150 epochs on modular division with $p = 97$, training on 50% of the data with a 2-layer, 1-head, 128-dim network.
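
The block below is a minimal sketch of the kind of pre-norm decoder block described above (RMSNorm, RoPE on queries and keys, SiLU MLP). It is illustrative only and not the repo's actual models.py, whose layer layout and names may differ.

```python
import mlx.core as mx
import mlx.nn as nn


class Block(nn.Module):
    """One pre-norm Transformer block: RMSNorm -> RoPE attention -> RMSNorm -> SiLU MLP."""

    def __init__(self, dims: int = 128, n_heads: int = 1):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dims // n_heads
        self.attn_norm = nn.RMSNorm(dims)
        self.mlp_norm = nn.RMSNorm(dims)
        self.wq = nn.Linear(dims, dims, bias=False)
        self.wk = nn.Linear(dims, dims, bias=False)
        self.wv = nn.Linear(dims, dims, bias=False)
        self.wo = nn.Linear(dims, dims, bias=False)
        self.rope = nn.RoPE(self.head_dim)
        self.up = nn.Linear(dims, 4 * dims, bias=False)
        self.down = nn.Linear(4 * dims, dims, bias=False)

    def __call__(self, x: mx.array, mask=None) -> mx.array:
        B, L, D = x.shape
        h = self.attn_norm(x)
        # Split heads, then rotate queries and keys with RoPE.
        q = self.wq(h).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.wk(h).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.wv(h).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        q, k = self.rope(q), self.rope(k)
        o = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.head_dim**-0.5, mask=mask)
        o = o.transpose(0, 2, 1, 3).reshape(B, L, D)
        x = x + self.wo(o)                                      # attention residual
        x = x + self.down(nn.silu(self.up(self.mlp_norm(x))))   # SiLU MLP residual
        return x
```

Two such blocks over an embedding of the input tokens, followed by a linear head over the $p$ residues, would match the 2-layer, 1-head, 128-dim size quoted above.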

Background

We define the following operations modulo a prime $p$, for operands $(a, b)$ with $0 \leq a \lt p$ and $0 \lt b \lt p$:

  • Addition: $a \circ b = a + b \mod p$
  • Subtraction: $a \circ b = a - b \mod p$
  • Multiplication: $a \circ b = a \cdot b \mod p$
  • Division: $a \circ b = a / b \mod p$, computed as $a \cdot b^{p-2} \mod p$: by Fermat’s Little Theorem, $b^{p-1} \equiv 1 \mod p$ for any $b$ not divisible by $p$, so $b^{p-2}$ is the multiplicative inverse of $b$ (see the sketch after this list).
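
As a quick illustration (plain Python, independent of the repo's code), modular division can be computed with the built-in three-argument pow:

```python
# Modular division via Fermat's little theorem:
# b^(p-1) ≡ 1 (mod p) for b not divisible by p, so b^(-1) ≡ b^(p-2) (mod p).
def mod_div(a: int, b: int, p: int = 97) -> int:
    assert 0 < b < p, "b must be a nonzero residue mod p"
    return (a * pow(b, p - 2, p)) % p

# Example: 3 / 5 ≡ 20 (mod 97), since 20 * 5 = 100 ≡ 3 (mod 97).
assert mod_div(3, 5) == 20
```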

Running

Run with the default parameters for $a / b \mod p$ and save the resulting plot to media/grokking.png:

python main.py

  • main.py: training and evaluation loops
  • models.py: defines the Transformer model
  • data.py: generates the dataset (a sketch of the construction follows below)
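
For reference, here is a sketch of how such a dataset can be constructed — an assumption about the setup (all pairs for the chosen operation, split 50/50 into train and held-out sets), not necessarily identical to what data.py does; the token layout (operation and "=" tokens) is hypothetical.

```python
import numpy as np

# A hypothetical dataset builder for "a / b = c (mod p)" examples.
# Tokens 0..p-1 are residues; p stands for the operation symbol, p+1 for "=".
def make_division_dataset(p: int = 97, train_fraction: float = 0.5, seed: int = 0):
    pairs = [(a, b) for a in range(p) for b in range(1, p)]
    op_tok, eq_tok = p, p + 1
    x = np.array([[a, op_tok, b, eq_tok] for a, b in pairs], dtype=np.int64)
    y = np.array([(a * pow(b, p - 2, p)) % p for a, b in pairs], dtype=np.int64)
    # Shuffle once and split: train_fraction of all pairs for training, the rest held out.
    perm = np.random.default_rng(seed).permutation(len(pairs))
    cut = int(train_fraction * len(pairs))
    return (x[perm[:cut]], y[perm[:cut]]), (x[perm[cut:]], y[perm[cut:]])
```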

Dependencies

Install the dependencies (optimized for Apple silicon; yay for MLX!):

pip install -r requirements.txt
