diff --git a/basalt/nn/__init__.mojo b/basalt/nn/__init__.mojo
index c2a0660..11dc4fa 100644
--- a/basalt/nn/__init__.mojo
+++ b/basalt/nn/__init__.mojo
@@ -11,6 +11,7 @@ from .activations import (
     LogSoftmax,
     ReLU,
     LeakyReLU,
+    GELU,
     Sigmoid,
     Tanh,
 )
diff --git a/basalt/nn/activations.mojo b/basalt/nn/activations.mojo
index 9a83a0f..d679399 100644
--- a/basalt/nn/activations.mojo
+++ b/basalt/nn/activations.mojo
@@ -17,6 +17,20 @@ fn LeakyReLU(
         attributes=AttributeVector(Attribute("negative_slope", negative_slope)),
     )
 
+fn GELU(inout g: Graph, input: Symbol) -> Symbol:
+    var SQRT_2_OVER_PI = 0.7978845608028654
+    var GELU_COEFF = 0.044715
+
+    var x_cubed = g.op(OP.POW, input, 3.0)
+    var term = g.op(OP.ADD, input, g.op(OP.MUL, GELU_COEFF, x_cubed))
+    var scaled_term = g.op(OP.MUL, SQRT_2_OVER_PI, term)
+    var tanh_result = g.op(OP.TANH, scaled_term)
+    var one_plus_tanh = g.op(OP.ADD, 1.0, tanh_result)
+    var gelu_output = g.op(OP.MUL, g.op(OP.MUL, 0.5, input), one_plus_tanh)
+
+    return gelu_output
+
+
 fn Sigmoid(inout g: Graph, input: Symbol) -> Symbol:
     return g.op(OP.SIGMOID, input)
 
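
The new `GELU` builds the tanh approximation, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), out of existing graph ops (POW, MUL, ADD, TANH), matching the style of the other activations in this file. Below is a minimal usage sketch; the `Graph()`, `g.input(TensorShape(...))`, `g.out(...)` calls and the imports are assumed from Basalt's example models and are not part of this diff, and the input shape is hypothetical.

```mojo
from basalt import Graph, TensorShape
from basalt.nn import GELU

fn gelu_graph() -> Graph:
    # Assumed graph-building flow, as in Basalt's example models.
    var g = Graph()
    var x = g.input(TensorShape(1, 16))  # hypothetical input shape
    # GELU from this diff: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    var y = GELU(g, x)
    g.out(y)
    return g ^
```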