Generated GPT_OSS model files through porter script. #2384
base: master
Changes from all commits: 26867ba, f1c055b, d4da96c, b14cfb5, b675610, 8cf71ce, 2242ef4, eb25d19, ba50a9f, 1854d80, 76139cd, 00ec305, 9447990, 340aa85, b02cfea, 47dcdda, 5e16f80, 79c5664, 59b6930, 8d3a658, d9396c6, 4a63e85, 285253f
@@ -25,6 +25,17 @@ class RotaryEmbedding(keras.layers.Layer):
            curves.
        scaling_factor: float. The scaling factor used to scale positions of
            the tokens.
        rope_type: str. The type of RoPE scaling to apply. Supported types:
            "linear", "dynamic", "yarn". Defaults to "linear".
        beta_fast: float. Beta fast parameter for YaRN scaling. Only used
            when rope_type="yarn". Defaults to 32.0.
        beta_slow: float. Beta slow parameter for YaRN scaling. Only used
            when rope_type="yarn". Defaults to 1.0.
        original_max_position_embeddings: int. Original maximum position
            embeddings for YaRN scaling. Only used when rope_type="yarn".
            Defaults to 4096.
        truncate: bool. Whether to apply truncation for YaRN scaling. Only
            used when rope_type="yarn". Defaults to False.
        sequence_axis: int. Sequence axis in the input tensor.
        feature_axis: int. Feature axis in the input tensor.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
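For context, a minimal usage sketch of the new arguments. This is illustrative only: the import path `keras_hub.layers.RotaryEmbedding` and the chosen shapes and values are assumptions, not taken from this PR.

```python
import numpy as np
from keras_hub.layers import RotaryEmbedding  # assumed export path

# YaRN-style RoPE scaling, mirroring the docstring defaults above.
rope = RotaryEmbedding(
    max_wavelength=10000,
    scaling_factor=8.0,  # context-extension factor "s"
    rope_type="yarn",
    beta_fast=32.0,
    beta_slow=1.0,
    original_max_position_embeddings=4096,
    truncate=False,
)
x = np.random.rand(2, 16, 8, 64).astype("float32")  # (batch, seq, heads, head_dim)
y = rope(x)  # same shape, with rotary position information applied
```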
@@ -69,33 +80,89 @@ def __init__(
        self,
        max_wavelength=10000,
        scaling_factor=1.0,
        rope_type="linear",
        beta_fast=32.0,
        beta_slow=1.0,
        original_max_position_embeddings=4096,
        truncate=False,
        sequence_axis=1,
        feature_axis=-1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.sequence_axis = sequence_axis
        self.feature_axis = feature_axis
        self.scaling_factor = scaling_factor
        self.built = True
        self.rope_type = rope_type

        # YaRN-specific parameters (only used when rope_type="yarn")
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.original_max_position_embeddings = original_max_position_embeddings
        self.truncate = truncate

        # Store original axis values for validation
        self._original_sequence_axis = sequence_axis
        self._original_feature_axis = feature_axis
Comment on lines +102 to +106 (Collaborator): Change this back to ... to avoid confusion with the previous implementation.
    def _normalize_axes(self, input_shape):
        """Normalize and validate axis indices for the given input shape."""
        rank = len(input_shape)

        # Normalize negative indices
        sequence_axis = self._original_sequence_axis
        feature_axis = self._original_feature_axis

        if sequence_axis < 0:
            sequence_axis += rank
        if feature_axis < 0:
            feature_axis += rank

        # Validate axis indices
Collaborator: remove this comment
        if sequence_axis < 0 or sequence_axis >= rank:
            raise ValueError(
                f"sequence_axis {self._original_sequence_axis} "
                f"is out of range for input with rank {rank}"
            )
        if feature_axis < 0 or feature_axis >= rank:
            raise ValueError(
                f"feature_axis {self._original_feature_axis} "
                f"is out of range for input with rank {rank}"
            )
        if sequence_axis == feature_axis:
            raise ValueError("sequence_axis and feature_axis must be different")

        return sequence_axis, feature_axis
    def _validate_rotary_dimension(self, rotary_dim):
        """Validate that the rotary (feature) dimension is even."""
Collaborator: remove this comment
        if rotary_dim % 2 != 0:
            raise ValueError(
                f"Rotary dimension must be even, got {rotary_dim}. "
                "The rotary embedding splits the feature dimension "
                "into two halves. Consider using a different feature "
                "dimension or padding."
            )
    def call(self, inputs, start_index=0, positions=None):
        # Normalize and validate axes
Collaborator: Remove this comment
        input_shape = ops.shape(inputs)
        sequence_axis, feature_axis = self._normalize_axes(input_shape)

        # Validate rotary dimension
Collaborator: Remove this comment
        rotary_dim = input_shape[feature_axis]
        self._validate_rotary_dimension(rotary_dim)

        # Take care of unbatched `positions`.
Collaborator: Remove this comment
        if positions is not None:
            if len(ops.shape(positions)) == 1:
                positions = ops.expand_dims(positions, axis=0)

-        inputs = ops.moveaxis(
-            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
-        )
        inputs = ops.moveaxis(inputs, (feature_axis, sequence_axis), (-1, 1))
        cos_emb, sin_emb = self._compute_cos_sin_embedding(
            inputs, start_index, positions
        )
        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
-        return ops.moveaxis(
-            output, (-1, 1), (self.feature_axis, self.sequence_axis)
-        )
        return ops.moveaxis(output, (-1, 1), (feature_axis, sequence_axis))

    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
        x1, x2 = ops.split(tensor, 2, axis=-1)
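The rest of `_apply_rotary_pos_emb` is truncated by the diff view. For reference, the rotate-half convention it appears to follow computes `x * cos + rotate_half(x) * sin`, where rotate-half maps the two feature halves `(x1, x2)` to `(-x2, x1)`. A minimal NumPy sketch of that identity (not this layer's exact code):

```python
import numpy as np

def rotate_half_apply(x, cos, sin):
    # Split the feature axis into two halves and map (x1, x2) -> (-x2, x1),
    # then combine with the cos/sin embeddings.
    x1, x2 = np.split(x, 2, axis=-1)
    half_rot = np.concatenate([-x2, x1], axis=-1)
    return x * cos + half_rot * sin
```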
@@ -113,51 +180,231 @@ def _compute_positions(self, inputs, start_index=0):
        return positions + ops.cast(start_index, dtype="float32")

    def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
        """Compute cos & sin RoPE embeddings with optional YaRN scaling.

        Uses tensor ops only to remain JIT- and backend-friendly.
        """
        batch_axis = 0
        feature_axis = len(inputs.shape) - 1
        sequence_axis = 1

        # rotary_dim is the full size of the feature (last) axis;
        # the embedding rotates pairs, so it uses rotary_dim // 2
        # frequencies (HF-style).
        rotary_dim = ops.shape(inputs)[feature_axis]
        # Validate evenness: best-effort check when running eagerly;
        # for symbolic shapes `int()` fails and the check is skipped.
        try:
            static_rotary_dim = int(rotary_dim)
        except TypeError:
            static_rotary_dim = None
        if static_rotary_dim is not None and static_rotary_dim % 2 != 0:
            raise ValueError(
                "Rotary embedding requires even feature "
                "dimension (last axis)."
            )

        # Get inverse frequencies using the appropriate
        # scaling method (linear, dynamic, yarn, etc.)
Comment on lines +190 to +206 (Collaborator): Let's not use a try/except block, can't we just use ...
        inverse_freq = self._get_inverse_freq(rotary_dim)

        # positions handling
        if positions is None:
            positions = self._compute_positions(inputs, start_index)
            positions = ops.expand_dims(
                positions, axis=batch_axis
            )  # shape (1, seq_len)
        else:
            # ensure float dtype and batch dim
            positions = ops.cast(positions, "float32")
            positions = positions / ops.cast(self.scaling_factor, "float32")
            if len(ops.shape(positions)) == 1:
                positions = ops.expand_dims(positions, axis=batch_axis)

        # Apply truncation for YaRN if specified
        if (
            self.rope_type == "yarn"
            and self.truncate
            and self.original_max_position_embeddings is not None
        ):
            positions = ops.minimum(
                positions,
                ops.cast(self.original_max_position_embeddings, "float32"),
            )

        # compute outer product positions x inverse_freq ->
        # shape (batch?, seq_len, rotary_dim // 2).
        # If positions has a batch dim, einsum handles it.
        freq = ops.einsum("bi,j->bij", positions, inverse_freq)

        # stack to interleave sin/cos dims and reshape to full rotary dim
        embedding = ops.stack((freq, freq), axis=-2)
        embedding = ops.reshape(
            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
        )

        # Expand embedding to match inputs rank
        # (insert axes for any non-batch/seq/feature dims)
        for axis in range(len(inputs.shape)):
            if axis not in (batch_axis, sequence_axis, feature_axis):
                embedding = ops.expand_dims(embedding, axis)

        cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
        sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
        # YaRN temperature scaling: implement in tensor ops
        if self.rope_type == "yarn":
            # t = (0.1 * ln(s) + 1) ** 2
            # make sure s > 0
            small = ops.cast(1e-6, self.compute_dtype)
            s_safe = ops.maximum(
                ops.cast(self.scaling_factor, self.compute_dtype), small
            )
            t = ops.square(
                ops.add(
                    ops.multiply(
                        ops.cast(0.1, self.compute_dtype), ops.log(s_safe)
                    ),
                    ops.cast(1.0, self.compute_dtype),
                )
            )
            sqrt_t = ops.sqrt(t)

            # HF/YaRN descriptions indicate a temperature scaling
            # applied to the cos/sin embeddings, equivalent to
            # scaling the logits. We implement the sqrt scaling on cos/sin.
            cos_emb = cos_emb * sqrt_t
            sin_emb = sin_emb * sqrt_t
Comment on lines +252 to +274 (Collaborator): Avoid variable names like ...
        return cos_emb, sin_emb
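For intuition on the temperature step above: since `t = (0.1 * ln(s) + 1) ** 2`, the applied factor is simply `sqrt(t) = 0.1 * ln(s) + 1`. With `scaling_factor=8.0` that is about 1.208, so cos/sin grow by roughly 21%. A quick check (illustrative values, not from this PR's tests):

```python
import numpy as np

s = 8.0                            # scaling_factor
t = (0.1 * np.log(s) + 1.0) ** 2   # YaRN attention temperature
print(np.sqrt(t))                  # ~1.208 -> cos/sin scaled up ~21%
```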
    def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.divide(
-            ops.arange(0, rotary_dim, 2, dtype="float32"),
-            ops.cast(rotary_dim, "float32"),
-        )
-        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
-        return inverse_freq
        """Return inverse frequencies."""
        # rotary_dim is expected to be a Python int or a small tensor;
        # create idx with an explicit dtype
        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
        denom = ops.cast(rotary_dim, "float32")
        freq_range = idx / denom
        inv = ops.power(ops.cast(self.max_wavelength, "float32"), -freq_range)

        # apply rope_scaling variants
        if self.rope_type == "linear":
            # linear: divide inverse freqs by the factor
            # (consistent with HF linear scaling semantics)
            return inv / ops.cast(self.scaling_factor, "float32")
        elif self.rope_type == "dynamic":
            # dynamic (NTK-aware) fallback, conservative implementation:
            # HF's dynamic implementation uses NTK-by-parts; use a
            # practical scaling to approximate it. Here we conservatively
            # divide by scaling_factor ** (rotary_dim / (rotary_dim - 2)).
            exponent = ops.cast(rotary_dim, "float32") / ops.cast(
                max(1, rotary_dim - 2), "float32"
            )
            return inv / ops.power(
                ops.cast(self.scaling_factor, "float32"), exponent
            )
        elif self.rope_type == "yarn":
            # Delegate to the more advanced YaRN inverse-freq routine
            return self._get_yarn_inverse_freq(inv, rotary_dim)
        else:
            return inv
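As a sanity check on the "linear" and "dynamic" branches, the same formulas in plain NumPy (the dimension and factor here are assumptions for illustration):

```python
import numpy as np

dim, base, factor = 64, 10000.0, 8.0
inv = base ** -(np.arange(0, dim, 2) / dim)   # unscaled inverse frequencies

linear = inv / factor                          # "linear": uniform 8x squeeze
dynamic = inv / factor ** (dim / (dim - 2))    # "dynamic": slightly stronger squeeze

print(inv[0], linear[0], dynamic[0])           # 1.0, 0.125, ~0.117
```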
    def _get_yarn_inverse_freq(self, base_inverse_freq, rotary_dim):
        """YaRN NTK-by-parts style inverse frequency scaling
        (tensor-friendly). This follows the YaRN paper and common
        porting decisions used in HF forks.
        """
        s = ops.cast(self.scaling_factor, "float32")

        # Get the base (rope_theta equivalent) from max_wavelength
        base = ops.cast(self.max_wavelength, "float32")

        # Compute base frequencies: base ** (idx / dim)
        idx = ops.arange(0, rotary_dim, 2, dtype="float32")
        pos_freqs = ops.power(base, idx / ops.cast(rotary_dim, "float32"))

        # Compute interpolation and extrapolation frequencies
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (s * pos_freqs)

        # Find correction range (same logic as HuggingFace)
        if (
            self.beta_fast is not None
            and self.beta_slow is not None
            and self.original_max_position_embeddings is not None
        ):
            L = ops.cast(self.original_max_position_embeddings, "float32")
            beta_fast = ops.cast(self.beta_fast, "float32")
            beta_slow = ops.cast(self.beta_slow, "float32")

            # Find correction dimensions for beta_fast and beta_slow
            def find_correction_dim_tensor(
                num_rotations, dim, base_val, max_pos
            ):
                return (
                    dim
                    * ops.log(max_pos / (num_rotations * 2 * 3.141592653589793))
                ) / (2 * ops.log(base_val))
Comment on lines +338 to +345 (Contributor): The value of pi is hardcoded here. It's better to use `math.pi`. Additionally, the nested function ...

Suggested change:

            def find_correction_dim_tensor(
                num_rotations, dim, base_val, max_pos
            ):
                return (
                    dim
                    * ops.log(max_pos / (num_rotations * 2 * math.pi))
                ) / (2 * ops.log(base_val))
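For reference, the formula under discussion follows from the YaRN paper's correction-dimension derivation: pair `i` has inverse frequency `base ** (-2i / dim)`, so over `max_pos` positions it completes `max_pos * base ** (-2i / dim) / (2 * pi)` rotations; setting that equal to `num_rotations` and solving for `i` gives `dim * log(max_pos / (num_rotations * 2 * pi)) / (2 * log(base))`, which is exactly what `find_correction_dim_tensor` computes.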
            low = find_correction_dim_tensor(
                beta_fast, ops.cast(rotary_dim, "float32"), base, L
            )
            high = find_correction_dim_tensor(
                beta_slow, ops.cast(rotary_dim, "float32"), base, L
            )

            # Apply truncation if specified
            if self.truncate:
                low = ops.floor(low)
                high = ops.ceil(high)

            # Clamp to valid range
            low = ops.maximum(low, ops.cast(0, "float32"))
            high = ops.minimum(high, ops.cast(rotary_dim // 2 - 1, "float32"))

            # Linear ramp function
            dim_half = rotary_dim // 2
            idx_half = ops.arange(0, dim_half, dtype="float32")

            # Prevent singularity
            diff = high - low
            diff = ops.maximum(diff, ops.cast(0.001, "float32"))

            linear_func = (idx_half - low) / diff
            ramp_func = ops.clip(linear_func, 0, 1)

            # Apply the ramp to get the extrapolation factor
            inv_freq_extrapolation_factor = 1 - ramp_func

            # Combine interpolation and extrapolation
            scaled_inverse_freq = (
                inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
                + inv_freq_extrapolation * inv_freq_extrapolation_factor
            )
        else:
            # Fallback to simple scaling
            alpha = ops.power(
                s,
                ops.cast(rotary_dim, "float32")
                / ops.cast(max(1, rotary_dim - 2), "float32"),
            )
            scaled_inverse_freq = base_inverse_freq / alpha

        return scaled_inverse_freq
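To make the by-parts blend concrete, a small NumPy sketch with assumed values (`dim=64`, `base=10000`, `s=8`, `beta_fast=32`, `beta_slow=1`, `L=4096`; the `truncate` floor/ceil step is omitted):

```python
import numpy as np

dim, base, s, L = 64, 10000.0, 8.0, 4096.0
beta_fast, beta_slow = 32.0, 1.0

pos_freqs = base ** (np.arange(0, dim, 2) / dim)
extrap, interp = 1.0 / pos_freqs, 1.0 / (s * pos_freqs)

def corr_dim(beta):
    # Pair index at which a frequency completes `beta` rotations over L.
    return dim * np.log(L / (beta * 2 * np.pi)) / (2 * np.log(base))

low, high = corr_dim(beta_fast), corr_dim(beta_slow)   # ~10.5 and ~22.5
ramp = np.clip((np.arange(dim // 2) - low) / max(high - low, 1e-3), 0, 1)
inv_freq = interp * ramp + extrap * (1 - ramp)

# Fast (high-frequency) pairs keep extrapolation; slow pairs get interpolated.
print(inv_freq[0] == extrap[0], np.isclose(inv_freq[-1], interp[-1]))
```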
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_wavelength": self.max_wavelength,
                "scaling_factor": self.scaling_factor,
-                "sequence_axis": self.sequence_axis,
-                "feature_axis": self.feature_axis,
                "rope_type": self.rope_type,
                "beta_fast": self.beta_fast,
                "beta_slow": self.beta_slow,
                "original_max_position_embeddings": (
                    self.original_max_position_embeddings
                ),
                "truncate": self.truncate,
                "sequence_axis": self._original_sequence_axis,
                "feature_axis": self._original_feature_axis,
            }
        )
        return config
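Since `get_config` now serializes the YaRN fields, a config round-trip should reproduce the layer. A quick sketch (same assumed import path as above):

```python
# Round-trip sketch: from_config(get_config()) should preserve YaRN settings.
rope = RotaryEmbedding(rope_type="yarn", scaling_factor=8.0, truncate=True)
clone = RotaryEmbedding.from_config(rope.get_config())
assert clone.rope_type == "yarn"
assert clone.scaling_factor == 8.0 and clone.truncate is True
```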