Skip to content

Commit 5d91fd1

Browse files
authored
Add duration_seconds field to AudioSamples (#587)
1 parent 93f5d47 commit 5d91fd1

File tree

5 files changed

+17
-9
lines changed

5 files changed

+17
-9
lines changed

examples/audio_decoding.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,6 @@ def play_audio(samples):
7676
# all streams start exactly at 0! This is not a bug in TorchCodec, this is a
7777
# property of the file that was defined when it was encoded.
7878
#
79-
# We only output the *start* of the samples, not the end or the duration. Those can
80-
# be easily derived from the number of samples and the sample rate:
81-
82-
duration_seconds = samples.data.shape[1] / samples.sample_rate
83-
print(f"Duration = {int(duration_seconds // 60)}m{int(duration_seconds % 60)}s.")
84-
8579
# %%
8680
# Specifying a range
8781
# ------------------

src/torchcodec/_frame.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ class AudioSamples(Iterable):
124124
"""The sample data (``torch.Tensor`` of float in [-1, 1], shape is ``(num_channels, num_samples)``)."""
125125
pts_seconds: float
126126
"""The :term:`pts` of the first sample, in seconds."""
127+
duration_seconds: float
128+
"""The duration of the samples, in seconds."""
127129
sample_rate: int
128130
"""The sample rate of the samples, in Hz."""
129131

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ def get_samples_played_in_range(
139139
else:
140140
offset_end = num_samples
141141

142+
data = frames[:, offset_beginning:offset_end]
142143
return AudioSamples(
143-
data=frames[:, offset_beginning:offset_end],
144+
data=data,
144145
pts_seconds=output_pts_seconds,
146+
duration_seconds=data.shape[1] / sample_rate,
145147
sample_rate=sample_rate,
146148
)

test/decoders/test_decoders.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,6 @@ def test_get_all_samples(self, asset, stop_seconds):
993993

994994
torch.testing.assert_close(samples.data, reference_frames)
995995
assert samples.sample_rate == asset.sample_rate
996-
997996
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
998997

999998
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
@@ -1215,3 +1214,10 @@ def test_s16_ffmpeg4_bug(self):
12151214
)
12161215
with cm:
12171216
decoder.get_samples_played_in_range(start_seconds=0)
1217+
1218+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1219+
@pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000))
1220+
def test_samples_duration(self, asset, sample_rate):
1221+
decoder = AudioDecoder(asset.path, sample_rate=sample_rate)
1222+
samples = decoder.get_samples_played_in_range(start_seconds=1, stop_seconds=2)
1223+
assert samples.duration_seconds == 1

test/test_frame_dataclasses.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
def test_unpacking():
77
data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa
8-
data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000)
8+
data, pts_seconds, duration_seconds, sample_rate = AudioSamples(
9+
torch.rand(2, 4), 2, 3, 16_000
10+
)
911

1012

1113
def test_frame_error():
@@ -147,11 +149,13 @@ def test_audio_samples_error():
147149
AudioSamples(
148150
data=torch.rand(1),
149151
pts_seconds=1,
152+
duration_seconds=1,
150153
sample_rate=16_000,
151154
)
152155
with pytest.raises(ValueError, match="data must be 2-dimensional"):
153156
AudioSamples(
154157
data=torch.rand(1, 2, 3),
155158
pts_seconds=1,
159+
duration_seconds=1,
156160
sample_rate=16_000,
157161
)

0 commit comments

Comments
 (0)