Changes from all commits (23 commits)
26867ba
Test GPT_OSS files through porter
laxmareddyp Sep 5, 2025
f1c055b
generate API and moved files to respective folders
laxmareddyp Sep 6, 2025
d4da96c
Fix format issues
laxmareddyp Sep 6, 2025
b14cfb5
Add gpt_oss to preset loader and Fix format issues
laxmareddyp Sep 6, 2025
b675610
Add gpt_oss to preset loader
laxmareddyp Sep 6, 2025
8cf71ce
generated files through 2.5-pro model
laxmareddyp Sep 8, 2025
2242ef4
Format fix
laxmareddyp Sep 10, 2025
eb25d19
Add converter, RoPE update
laxmareddyp Sep 11, 2025
ba50a9f
Fix format
laxmareddyp Sep 11, 2025
1854d80
Fix BPE tests
laxmareddyp Sep 12, 2025
76139cd
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 12, 2025
00ec305
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 13, 2025
9447990
Update converter
laxmareddyp Sep 13, 2025
340aa85
Fix converter, checkpoints conversion and attention
laxmareddyp Sep 13, 2025
b02cfea
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 24, 2025
47dcdda
Fix the parameter count and debug code
laxmareddyp Sep 24, 2025
5e16f80
Add dequantization logic to converter
laxmareddyp Sep 25, 2025
79c5664
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Oct 9, 2025
59b6930
Add YaRN support,Fix Serialisation,Fix dequantization
laxmareddyp Oct 9, 2025
8d3a658
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Nov 11, 2025
d9396c6
Fixed several pytest tests
laxmareddyp Nov 11, 2025
4a63e85
Address gpt_oss_causal_lm tests
laxmareddyp Nov 12, 2025
285253f
Fix format issues
laxmareddyp Nov 12, 2025
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -340,6 +340,18 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
GptOssBackbone as GptOssBackbone,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
GptOssCausalLM as GptOssCausalLM,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
HGNetV2Backbone as HGNetV2Backbone,
)
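These generated re-exports make the GPT-OSS classes reachable from the public namespace. A minimal import sketch, assuming the generated api package surfaces under keras_hub.models and keras_hub.tokenizers the same way it does for the existing models:

import keras_hub

# Public aliases added by this PR (per the generated __init__.py files).
backbone_cls = keras_hub.models.GptOssBackbone
causal_lm_cls = keras_hub.models.GptOssCausalLM
preprocessor_cls = keras_hub.models.GptOssCausalLMPreprocessor
tokenizer_cls = keras_hub.tokenizers.GptOssTokenizer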
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -47,6 +47,9 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.llama.llama_tokenizer import (
LlamaTokenizer as LlamaTokenizer,
)
287 changes: 267 additions & 20 deletions keras_hub/src/layers/modeling/rotary_embedding.py
@@ -25,6 +25,17 @@ class RotaryEmbedding(keras.layers.Layer):
curves.
scaling_factor: float. The scaling factor used to scale positions of
the tokens.
rope_type: str. The type of RoPE scaling to apply. Supported types:
"linear", "dynamic", "yarn". Defaults to "linear".
beta_fast: float. Beta fast parameter for YaRN scaling. Only used
when rope_type="yarn". Defaults to 32.0.
beta_slow: float. Beta slow parameter for YaRN scaling. Only used
when rope_type="yarn". Defaults to 1.0.
original_max_position_embeddings: int. Original maximum position
embeddings for YaRN scaling. Only used when rope_type="yarn".
Defaults to 4096.
truncate: bool. Whether to apply truncation for YaRN scaling. Only used
when rope_type="yarn". Defaults to False.
sequence_axis: int. Sequence axis in the input tensor.
feature_axis: int. Feature axis in the input tensor.
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
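A minimal usage sketch for the extended constructor, assuming the argument names documented above; the shapes and values below are illustrative only:

import numpy as np

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding

# (batch=2, seq_len=8, num_heads=4, head_dim=64); the feature axis must be even.
x = np.random.rand(2, 8, 4, 64).astype("float32")

rope = RotaryEmbedding(
    max_wavelength=10000,
    scaling_factor=2.0,
    rope_type="yarn",
    beta_fast=32.0,
    beta_slow=1.0,
    original_max_position_embeddings=4096,
)
full = rope(x)  # rotates positions 0..7
step = rope(x[:, -1:], start_index=7)  # single-token decode with a cache offset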
@@ -69,33 +80,89 @@ def __init__(
self,
max_wavelength=10000,
scaling_factor=1.0,
rope_type="linear",
beta_fast=32.0,
beta_slow=1.0,
original_max_position_embeddings=4096,
truncate=False,
sequence_axis=1,
feature_axis=-1,
**kwargs,
):
super().__init__(**kwargs)
self.max_wavelength = max_wavelength
self.sequence_axis = sequence_axis
self.feature_axis = feature_axis
self.scaling_factor = scaling_factor
self.built = True
self.rope_type = rope_type

# YaRN-specific parameters (only used when rope_type="yarn")
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.original_max_position_embeddings = original_max_position_embeddings
self.truncate = truncate

# Store original axis values for validation
Review comment (Collaborator): Remove this comment.

self._original_sequence_axis = sequence_axis
self._original_feature_axis = feature_axis
Review comment (Contributor, high): The self.built = True statement was removed from __init__. For layers that do not have weights that need to be built in a build() method, it's important to set self.built = True at the end of __init__ to indicate to the framework that the layer is already built. Please add it back to ensure correct layer state.

Suggested change:
self._original_feature_axis = feature_axis
self.built = True


Review comment on lines +102 to +106 (Collaborator @sachinprasadhs, Nov 14, 2025): Change this back to

self.sequence_axis = sequence_axis
self.feature_axis = feature_axis

to avoid confusion with the previous implementation, and add self.built = True.

def _normalize_axes(self, input_shape):
"""Normalize and validate axis indices for the given input shape."""
rank = len(input_shape)

# Normalize negative indices
sequence_axis = self._original_sequence_axis
feature_axis = self._original_feature_axis

if sequence_axis < 0:
sequence_axis += rank
if feature_axis < 0:
feature_axis += rank

# Validate axis indices
Review comment (Collaborator): remove this comment.

if sequence_axis < 0 or sequence_axis >= rank:
raise ValueError(
f"sequence_axis {self._original_sequence_axis} "
f"is out of range for input with rank {rank}"
)
if feature_axis < 0 or feature_axis >= rank:
raise ValueError(
f"feature_axis {self._original_feature_axis} "
f"is out of range for input with rank {rank}"
)
if sequence_axis == feature_axis:
raise ValueError("sequence_axis and feature_axis must be different")

return sequence_axis, feature_axis

def _validate_rotary_dimension(self, rotary_dim):
"""Validate that rotary dimension is even and handle odd dimensions."""
Review comment (Collaborator): remove this comment.

if rotary_dim % 2 != 0:
raise ValueError(
f"Rotary dimension must be even, got {rotary_dim}."
"The rotary embedding splits the feature dimension "
"into two halves. Consider using a different feature "
"dimension or padding."
)

def call(self, inputs, start_index=0, positions=None):
# Normalize and validate axes
Review comment (Collaborator): Remove this comment.

input_shape = ops.shape(inputs)
sequence_axis, feature_axis = self._normalize_axes(input_shape)

# Validate rotary dimension
Review comment (Collaborator): Remove this comment.

rotary_dim = input_shape[feature_axis]
self._validate_rotary_dimension(rotary_dim)

# Take care of unbatched `positions`.
Review comment (Collaborator): Remove this comment.

if positions is not None:
if len(ops.shape(positions)) == 1:
positions = ops.expand_dims(positions, axis=0)

inputs = ops.moveaxis(
inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
)
inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
cos_emb, sin_emb = self._compute_cos_sin_embedding(
inputs, start_index, positions
)
output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
return ops.moveaxis(
output, (-1, 1), (self.feature_axis, self.sequence_axis)
)
return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
x1, x2 = ops.split(tensor, 2, axis=-1)
@@ -113,51 +180,231 @@ def _compute_positions(self, inputs, start_index=0):
return positions + ops.cast(start_index, dtype="float32")

def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
"""Compute cos & sin RoPE embeddings with optional YaRN scaling.
Uses tensor ops only to remain JIT/backends friendly.
"""
batch_axis = 0
feature_axis = len(inputs.shape) - 1
sequence_axis = 1
feature_axis = len(inputs.shape) - 1

# rotary_dim is the full size of the feature axis; the rotation
# is applied to pairs of dimensions (HF-style)
rotary_dim = ops.shape(inputs)[feature_axis]
# Validate evenness
try:
# best-effort check when running eagerly;
# if unavailable this will be a no-op
if int(rotary_dim) % 2 != 0:
raise ValueError(
"Rotary embedding requires even feature "
"dimension (last axis)."
)
except Exception:
pass

# Get inverse frequencies using the appropriate
# scaling method (linear, dynamic, yarn, etc.)
Review comment on lines +190 to +206 (Collaborator @sachinprasadhs, Nov 14, 2025): Let's not use a try/except block; can't we just use the _validate_rotary_dimension function? Also, remove the comments.

inverse_freq = self._get_inverse_freq(rotary_dim)

# positions handling
if positions is None:
positions = self._compute_positions(inputs, start_index)
positions = ops.expand_dims(positions, axis=batch_axis)
positions = ops.expand_dims(
positions, axis=batch_axis
) # shape (1, seq_len)
else:
# ensure float dtype and batch dim
positions = ops.cast(positions, "float32")
positions = positions / ops.cast(self.scaling_factor, "float32")
if len(ops.shape(positions)) == 1:
positions = ops.expand_dims(positions, axis=batch_axis)

# Apply truncation for YaRN if specified
if (
self.rope_type == "yarn"
and self.truncate
and self.original_max_position_embeddings is not None
):
positions = ops.minimum(
positions,
ops.cast(self.original_max_position_embeddings, "float32"),
)

# compute outer product positions x inverse_freq ->
# shape (batch?, seq_len, rotary_dim//2)
# If positions has batch dim, einsum handles it.
freq = ops.einsum("bi,j->bij", positions, inverse_freq)

# stack to interleave sin/cos dims and reshape to full rotary dim
embedding = ops.stack((freq, freq), axis=-2)
embedding = ops.reshape(
embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
)

# Expand embedding to match inputs rank
# (insert axes for any non-batch/seq/feature dims)
for axis in range(len(inputs.shape)):
if axis not in (batch_axis, sequence_axis, feature_axis):
embedding = ops.expand_dims(embedding, axis)

cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)

# YaRN temperature scaling: implement in tensor ops
if self.rope_type == "yarn":
# t = (0.1 * ln(s) + 1)^2
# make sure s > 0
small = ops.cast(1e-6, self.compute_dtype)
s_safe = ops.maximum(
ops.cast(self.scaling_factor, self.compute_dtype), small
)
t = ops.square(
ops.add(
ops.multiply(
ops.cast(0.1, self.compute_dtype), ops.log(s_safe)
),
ops.cast(1.0, self.compute_dtype),
)
)
sqrt_t = ops.sqrt(t)

# HF/YaRN descriptions indicate a temperature scaling applied to the
# cos/sin embeddings, equivalent to scaling the attention logits.
# We implement the sqrt scaling on cos/sin.
cos_emb = cos_emb * sqrt_t
sin_emb = sin_emb * sqrt_t
Review comment on lines +252 to +274 (Collaborator): Avoid variable names like t, small, etc.; give meaningful variable names.


return cos_emb, sin_emb
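For reference, the YaRN attention-temperature factor computed above, written as a standalone numeric sketch (the scaling factor value is illustrative; the square root of the factor is what gets folded into cos_emb and sin_emb):

import math

scaling_factor = 8.0  # illustrative YaRN factor s
t = (0.1 * math.log(scaling_factor) + 1.0) ** 2
sqrt_t = math.sqrt(t)  # multiplied into cos_emb and sin_emb above
print(round(t, 4), round(sqrt_t, 4))  # 1.4591 1.2079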

def _get_inverse_freq(self, rotary_dim):
freq_range = ops.divide(
ops.arange(0, rotary_dim, 2, dtype="float32"),
ops.cast(rotary_dim, "float32"),
)
inverse_freq = 1.0 / (self.max_wavelength**freq_range)
return inverse_freq
"""Return inverse frequencies."""
# rotary_dim expected to be python int or small tensor;
# create idx with dtype
idx = ops.arange(0, rotary_dim, 2, dtype="float32")
denom = ops.cast(rotary_dim, "float32")
freq_range = idx / denom
inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)

# apply rope_scaling variants
if self.rope_type == "linear":
# linear: divide inverse freqs by factor
# (consistent with HF linear scaling semantics)
return inv / ops.cast(self.scaling_factor, "float32")
elif self.rope_type == "dynamic":
# dynamic (NTK-aware) fallback conservative implementation:
# HF dynamic implementation uses NTK-by-parts;
# use a practical scaling to approximate.
# Here we conservatively divide
# by scaling_factor^(rotary_dim/(rotary_dim-2))
exponent = ops.cast(rotary_dim, "float32") / ops.cast(
max(1, rotary_dim - 2), "float32"
)
return inv / ops.power(
ops.cast(self.scaling_factor, "float32"), exponent
)
elif self.rope_type == "yarn":
# Delegate to more advanced YaRN inverse freq routine
return self._get_yarn_inverse_freq(inv, rotary_dim)
else:
return inv
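A small numeric sketch of the "linear" and "dynamic" branches above, using numpy only for brevity (values are illustrative):

import numpy as np

rotary_dim, max_wavelength, s = 8, 10000.0, 2.0
freq_range = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
inv = max_wavelength ** -freq_range  # base inverse frequencies

linear = inv / s  # "linear": divide inverse freqs by the factor
exponent = rotary_dim / max(1, rotary_dim - 2)
dynamic = inv / (s ** exponent)  # "dynamic": conservative NTK-style exponent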

def _get_yarn_inverse_freq(self, base_inverse_freq, rotary_dim):
"""YaRN NTK-by-parts style inverse frequency scaling
(tensor-friendly).This follows the YaRN paper and common
porting decisions used in HF forks.
"""
s = ops.cast(self.scaling_factor, "float32")

# Get the base (rope_theta equivalent) from max_wavelength
base = ops.cast(self.max_wavelength, "float32")

# Compute base frequencies: base ** (idx / dim)
idx = ops.arange(0, rotary_dim, 2, dtype="float32")
pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))

# Compute interpolation and extrapolation frequencies
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (s * pos_freqs)

# Find correction range (same logic as HuggingFace)
if (
self.beta_fast is not None
and self.beta_slow is not None
and self.original_max_position_embeddings is not None
):
L = ops.cast(self.original_max_position_embeddings, "float32")
beta_fast = ops.cast(self.beta_fast, "float32")
beta_slow = ops.cast(self.beta_slow, "float32")

# Find correction dimensions for beta_fast and beta_slow
def find_correction_dim_tensor(
num_rotations, dim, base_val, max_pos
):
return (
dim
* ops.log(max_pos / (num_rotations * 2 * 3.141592653589793))
) / (2 * ops.log(base_val))
Review comment on lines +338 to +345 (Contributor, medium): The value of pi is hardcoded here. It's better to use math.pi for precision and readability. You'll need to add import math at the top of the file.

Additionally, the nested function find_correction_dim_tensor does not depend on any instance state and could be defined as a static method on the class or a helper function outside the class, for better code organization and to avoid potential JIT compilation issues.

Suggested change:
            def find_correction_dim_tensor(
                num_rotations, dim, base_val, max_pos
            ):
                return (
                    dim
                    * ops.log(max_pos / (num_rotations * 2 * math.pi))
                ) / (2 * ops.log(base_val))


low = find_correction_dim_tensor(
beta_fast, ops.cast(rotary_dim, "float32"), base, L
)
high = find_correction_dim_tensor(
beta_slow, ops.cast(rotary_dim, "float32"), base, L
)

# Apply truncation if specified
if self.truncate:
low = ops.floor(low)
high = ops.ceil(high)

# Clamp to valid range
low = ops.maximum(low, ops.cast(0, "float32"))
high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))

# Linear ramp function
dim_half = rotary_dim // 2
idx_half = ops.arange(0, dim_half, dtype="float32")

# Prevent singularity
diff = high - low
diff = ops.maximum(diff, ops.cast(0.001, "float32"))

linear_func = (idx_half - low) / diff
ramp_func = ops.clip(linear_func, 0, 1)

# Apply the ramp to get extrapolation factor
inv_freq_extrapolation_factor = 1 - ramp_func

# Combine interpolation and extrapolation
scaled_inverse_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
else:
# Fallback to simple scaling
alpha = ops.power(
s,
ops.cast(rotary_dim, "float32")
/ ops.cast(max(1, rotary_dim - 2), "float32"),
)
scaled_inverse_freq = base_inverse_freq / alpha

return scaled_inverse_freq
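To illustrate the NTK-by-parts blend above (and one way the nested helper could be lifted out and use math.pi, as the review comment suggests), a standalone numpy sketch with illustrative values and truncate=True semantics (floor/ceil applied to the correction range):

import math

import numpy as np


def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Dimension index whose wavelength completes `num_rotations` rotations
    # over `max_position_embeddings` tokens (YaRN NTK-by-parts boundary).
    return (
        dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    ) / (2 * math.log(base))


rotary_dim, base, s = 64, 10000.0, 8.0
beta_fast, beta_slow, original_max = 32.0, 1.0, 4096.0

idx = np.arange(0, rotary_dim, 2, dtype="float32")
pos_freqs = base ** (idx / rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs        # unscaled frequencies
inv_freq_interpolation = 1.0 / (s * pos_freqs)  # position-interpolated

low = max(
    math.floor(find_correction_dim(beta_fast, rotary_dim, base, original_max)), 0
)
high = min(
    math.ceil(find_correction_dim(beta_slow, rotary_dim, base, original_max)),
    rotary_dim // 2 - 1,
)

# Linear ramp from the low- to the high-correction dimension.
ramp = np.clip((np.arange(rotary_dim // 2) - low) / max(high - low, 0.001), 0, 1)
extrapolation_factor = 1.0 - ramp

# High-frequency dims keep the unscaled freqs; low-frequency dims are interpolated.
inv_freq = (
    inv_freq_interpolation * (1.0 - extrapolation_factor)
    + inv_freq_extrapolation * extrapolation_factor
)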

def get_config(self):
config = super().get_config()
config.update(
{
"max_wavelength": self.max_wavelength,
"scaling_factor": self.scaling_factor,
"sequence_axis": self.sequence_axis,
"feature_axis": self.feature_axis,
"rope_type": self.rope_type,
"beta_fast": self.beta_fast,
"beta_slow": self.beta_slow,
"original_max_position_embeddings": (
self.original_max_position_embeddings
),
"truncate": self.truncate,
"sequence_axis": self._original_sequence_axis,
"feature_axis": self._original_feature_axis,
}
)
return config
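A quick serialization round-trip sketch for the updated config, assuming the constructor shown earlier in this diff:

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding

layer = RotaryEmbedding(rope_type="yarn", scaling_factor=2.0, truncate=True)
config = layer.get_config()
restored = RotaryEmbedding.from_config(config)
assert restored.rope_type == "yarn" and restored.truncate is True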