diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py index 7dcc4e2c5ca8..a64148e98b41 100644 --- a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py +++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py @@ -305,13 +305,10 @@ def __call__( is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence)) is_batched = is_batched_numpy or is_batched_sequence - if is_batched: - raw_speech = [np.asarray([rs]).T for rs in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech) - - if not is_batched: # always return a batch - raw_speech = [np.asarray([raw_speech])] + # Always return a batch + if not is_batched: + raw_speech = [raw_speech] + raw_speech = [np.asarray([rs]).T for rs in raw_speech] batched_speech = self.pad( BatchFeature({"input_features": raw_speech}), diff --git a/tests/models/gemma3n/test_feature_extraction_gemma3n.py b/tests/models/gemma3n/test_feature_extraction_gemma3n.py index 92359040c9d3..fd3db8156e0d 100644 --- a/tests/models/gemma3n/test_feature_extraction_gemma3n.py +++ b/tests/models/gemma3n/test_feature_extraction_gemma3n.py @@ -228,6 +228,13 @@ def test_call(self, audio_inputs, test_truncation=False): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + def test_call_unbatched(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + np_audio = floats_list((1, 800))[0] + input_features = feature_extractor(np_audio, return_tensors="np").input_features + expected_input_features = feature_extractor([np_audio], return_tensors="np").input_features + np.testing.assert_allclose(input_features, expected_input_features) + def test_audio_features_attn_mask_consistent(self): # regression test for https://github.com/huggingface/transformers/issues/39911 # Test input_features and input_features_mask have consistent shape