Skip to content

Commit 4ee3db5

Browse files
authored
Fix sample rate conversion bug with multi-channel data (#584)
1 parent 4d894d5 commit 4ee3db5

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,17 +1509,22 @@ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
15091509
return std::nullopt;
15101510
}
15111511

1512-
torch::Tensor lastSamples = torch::empty(
1513-
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
1514-
torch::kFloat32);
1515-
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());
1512+
auto numChannels = getNumChannels(streamInfo.codecContext);
1513+
torch::Tensor lastSamples =
1514+
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
1515+
1516+
std::vector<uint8_t*> outputBuffers(numChannels);
1517+
for (auto i = 0; i < numChannels; i++) {
1518+
outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr());
1519+
}
15161520

15171521
auto actualNumRemainingSamples = swr_convert(
15181522
streamInfo.swrContext.get(),
1519-
&lastSamplesData,
1523+
outputBuffers.data(),
15201524
numRemainingSamples,
15211525
nullptr,
15221526
0);
1527+
15231528
return lastSamples.narrow(
15241529
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
15251530
}

test/decoders/test_decoders.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,14 @@ def test_sample_rate_conversion(self, start_seconds, stop_seconds):
11571157
rtol=rtol,
11581158
)
11591159

1160+
def test_sample_rate_conversion_stereo(self):
1161+
# Non-regression test for https://github.com/pytorch/torchcodec/pull/584
1162+
asset = NASA_AUDIO_MP3
1163+
assert asset.sample_rate == 8000
1164+
assert asset.num_channels == 2
1165+
decoder = AudioDecoder(asset.path, sample_rate=44_100)
1166+
decoder.get_samples_played_in_range(start_seconds=0)
1167+
11601168
def test_s16_ffmpeg4_bug(self):
11611169
# s16 fails on FFmpeg4 but can be decoded on other versions.
11621170
# Debugging logs show that we're hitting:

0 commit comments

Comments
 (0)