Grokking is a phenomenon observed in neural nets where, after an initial phase of overfitting (or memorization), the model suddenly achieves perfect generalization (Power et al., 2022). Inspired by that work, we incorporate some modern Transformer tricks (e.g., RoPE, RMSNorm, SiLU) and achieve grokking in < 150 epochs on modular division.
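For a rough sense of what those tricks look like in MLX, here is a minimal sketch of the components only (not the repo's actual `models.py`; dimensions and module wiring here are illustrative assumptions):

```python
# Sketch of the "modern tricks" as MLX modules -- illustrative, not the repo's models.py.
import mlx.core as mx
import mlx.nn as nn

dims, num_heads, seq_len = 128, 4, 6
head_dim = dims // num_heads

x = mx.random.normal((1, seq_len, dims))   # dummy token embeddings

norm = nn.RMSNorm(dims)                    # RMSNorm in place of LayerNorm
rope = nn.RoPE(head_dim)                   # rotary position embeddings, applied per head
up, down = nn.Linear(dims, 4 * dims), nn.Linear(4 * dims, dims)

h = norm(x)

# RoPE rotates queries/keys after reshaping to (batch, heads, seq, head_dim).
q = h.reshape(1, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
q_rotated = rope(q)

# SiLU (swish) activation in the position-wise feed-forward path.
y = down(nn.silu(up(h)))
print(q_rotated.shape, y.shape)            # (1, 4, 6, 32) (1, 6, 128)
```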
We define modular arithmetic for the following binary operations, given a prime modulus $p$:

- Addition: $a \circ b = a + b \mod p$
- Subtraction: $a \circ b = a - b \mod p$
- Multiplication: $a \circ b = a \cdot b \mod p$
- Division: $a \circ b = a / b \mod p$, computed via Fermat's Little Theorem, which states that $b^{p-1} \equiv 1 \mod p$ for any $b$ not divisible by $p$, so $a / b \equiv a \cdot b^{p-2} \mod p$ (see the sketch below).
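As a sketch of how the division equations can be enumerated (illustrative only; the function name and the choice of $p$ below are not taken from `data.py`), dividing by $b$ amounts to multiplying by the modular inverse $b^{p-2} \mod p$:

```python
# Sketch: enumerate all (a, b, a / b mod p) triples for the division task.
# Illustrative only -- not necessarily how data.py builds the dataset.
def make_division_triples(p: int):
    """Yield (a, b, c) such that c = a / b mod p, i.e., a = b * c mod p."""
    for a in range(p):
        for b in range(1, p):             # skip b = 0 so the inverse exists
            b_inv = pow(b, p - 2, p)      # Fermat: b^(p-2) is the inverse of b mod p
            yield a, b, (a * b_inv) % p


triples = list(make_division_triples(97))  # p = 97 is a common choice in the grokking literature
print(len(triples))                        # 97 * 96 = 9312 equations
a, b, c = triples[123]
assert (b * c) % 97 == a                   # sanity check: c really equals a / b mod 97
```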
Run with default params to reproduce `media/grokking.png`:

```shell
python main.py
```

- `main.py`: training and evaluation loops
- `models.py`: defines the Transformer model
- `data.py`: generates the dataset (see the split sketch below)
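In the standard grokking setup (Power et al., 2022), the model is trained on only a fraction of all possible equations and evaluated on the held-out rest. A rough sketch of such a split follows; the 50% fraction and the function name are assumptions here, not read from `data.py`:

```python
# Sketch of a train/validation split over all equations (assumed setup, not data.py's exact code).
import random


def split_equations(triples, train_frac: float = 0.5, seed: int = 0):
    """Shuffle all (a, b, result) equations and hold out a fraction for validation."""
    triples = list(triples)
    random.Random(seed).shuffle(triples)
    n_train = int(train_frac * len(triples))
    return triples[:n_train], triples[n_train:]


# e.g. with the division triples from the earlier sketch:
# train, val = split_equations(make_division_triples(97), train_frac=0.5)
```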
Install the dependencies (optimized for Apple silicon; yay for MLX!):
```shell
pip install -r requirements.txt
```
