Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/transformers/models/gemma3n/feature_extraction_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,10 @@ def __call__(
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
is_batched = is_batched_numpy or is_batched_sequence

if is_batched:
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)

if not is_batched: # always return a batch
raw_speech = [np.asarray([raw_speech])]
# Always return a batch
if not is_batched:
raw_speech = [raw_speech]
raw_speech = [np.asarray([rs]).T for rs in raw_speech]

batched_speech = self.pad(
BatchFeature({"input_features": raw_speech}),
Expand Down
7 changes: 7 additions & 0 deletions tests/models/gemma3n/test_feature_extraction_gemma3n.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ def test_call(self, audio_inputs, test_truncation=False):
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

def test_call_unbatched(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_audio = floats_list((1, 800))[0]
input_features = feature_extractor(np_audio, return_tensors="np").input_features
expected_input_features = feature_extractor([np_audio], return_tensors="np").input_features
np.testing.assert_allclose(input_features, expected_input_features)

def test_audio_features_attn_mask_consistent(self):
# regression test for https://github.com/huggingface/transformers/issues/39911
# Test input_features and input_features_mask have consistent shape
Expand Down
Loading