Skip to content

Commit 8f4eabb

Browse files
khoAbdennacer-Badaoui
authored and committed
Correctly handle unbatched audio inputs in Gemma3nAudioFeatureExtractor (huggingface#42076)
* Correctly handle unbatched audio inputs in Gemma3nAudioFeatureExtractor
* Simplify the logic for batching the unbatched speech input in Gemma3nAudioFeatureExtractor
1 parent 5754c67 commit 8f4eabb

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/transformers/models/gemma3n/feature_extraction_gemma3n.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -305,13 +305,10 @@ def __call__(
305305
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
306306
is_batched = is_batched_numpy or is_batched_sequence
307307

308-
if is_batched:
309-
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
310-
elif not is_batched and not isinstance(raw_speech, np.ndarray):
311-
raw_speech = np.asarray(raw_speech)
312-
313-
if not is_batched: # always return a batch
314-
raw_speech = [np.asarray([raw_speech])]
308+
# Always return a batch
309+
if not is_batched:
310+
raw_speech = [raw_speech]
311+
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
315312

316313
batched_speech = self.pad(
317314
BatchFeature({"input_features": raw_speech}),

tests/models/gemma3n/test_feature_extraction_gemma3n.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,13 @@ def test_call(self, audio_inputs, test_truncation=False):
228228
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
229229
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
230230

231+
def test_call_unbatched(self):
232+
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
233+
np_audio = floats_list((1, 800))[0]
234+
input_features = feature_extractor(np_audio, return_tensors="np").input_features
235+
expected_input_features = feature_extractor([np_audio], return_tensors="np").input_features
236+
np.testing.assert_allclose(input_features, expected_input_features)
237+
231238
def test_audio_features_attn_mask_consistent(self):
232239
# regression test for https://github.com/huggingface/transformers/issues/39911
233240
# Test input_features and input_features_mask have consistent shape

0 commit comments

Comments
 (0)