Skip to content

Commit 86ac211

Browse files
committed
Add tests.
1 parent a460877 commit 86ac211

File tree

1 file changed

+189
-0
lines changed

1 file changed

+189
-0
lines changed

tests/test_archive.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import os
2+
import numpy as np
3+
import pytest
4+
import tempfile
5+
import zipfile
6+
7+
from wfdb import rdrecord, wrsamp
8+
from wfdb.io.archive import WFDBArchive
9+
10+
11+
np.random.seed(1234)
12+
13+
14+
@pytest.fixture
15+
def temp_record():
16+
"""
17+
Create a temporary WFDB record and archive for testing.
18+
19+
This fixture generates a synthetic 2-channel signal, writes it to a temporary
20+
directory using `wrsamp`, then creates an uncompressed `.wfdb` archive (ZIP container)
21+
containing the `.hea` and `.dat` files. The archive is used to test read/write
22+
round-trip support for WFDB archives.
23+
24+
Yields
25+
------
26+
dict
27+
A dictionary containing:
28+
- 'record_name': Path to the record base name (without extension).
29+
- 'archive_path': Full path to the created `.wfdb` archive.
30+
- 'original_signal': The original NumPy array of the signal.
31+
- 'fs': The sampling frequency.
32+
"""
33+
with tempfile.TemporaryDirectory() as tmpdir:
34+
record_basename = "testrecord"
35+
fs = 250
36+
sig_len = 1000
37+
sig = (np.random.randn(sig_len, 2) * 1000).astype(np.float32)
38+
39+
# Write into tmpdir with record name only
40+
wrsamp(
41+
record_name=record_basename,
42+
fs=fs,
43+
units=["mV", "mV"],
44+
sig_name=["I", "II"],
45+
p_signal=sig,
46+
fmt=["24", "24"],
47+
adc_gain=[200.0, 200.0],
48+
baseline=[0, 0],
49+
write_dir=tmpdir,
50+
)
51+
52+
# Construct full paths for archive creation
53+
hea_path = os.path.join(tmpdir, record_basename + ".hea")
54+
dat_path = os.path.join(tmpdir, record_basename + ".dat")
55+
archive_path = os.path.join(tmpdir, record_basename + ".wfdb")
56+
57+
WFDBArchive.create_archive(
58+
None,
59+
file_list=[hea_path, dat_path],
60+
output_path=archive_path,
61+
)
62+
63+
yield {
64+
"record_name": os.path.join(tmpdir, record_basename),
65+
"archive_path": archive_path,
66+
"original_signal": sig,
67+
"fs": fs,
68+
}
69+
70+
71+
def test_wfdb_archive_inline_round_trip():
72+
"""
73+
There are two ways of creating an archive:
74+
75+
1. Inline archive creation via wrsamp(..., wfdb_archive=...)
76+
This creates the .hea and .dat files directly inside the archive as part of the record writing step.
77+
78+
2. Two-step creation via wrsamp(...) followed by WFDBArchive.create_archive(...)
79+
This writes regular WFDB files to disk, which are then added to an archive container afterward.
80+
81+
Test round-trip read/write using inline archive creation via `wrsamp(..., wfdb_archive=...)`.
82+
"""
83+
with tempfile.TemporaryDirectory() as tmpdir:
84+
record_basename = "testrecord"
85+
record_path = os.path.join(tmpdir, record_basename)
86+
archive_path = record_path + ".wfdb"
87+
fs = 250
88+
sig_len = 1000
89+
sig = (np.random.randn(sig_len, 2) * 1000).astype(np.float32)
90+
91+
# Create archive inline
92+
wfdb_archive = WFDBArchive(record_basename, mode="w")
93+
wrsamp(
94+
record_name=record_basename,
95+
fs=fs,
96+
units=["mV", "mV"],
97+
sig_name=["I", "II"],
98+
p_signal=sig,
99+
fmt=["24", "24"],
100+
adc_gain=[200.0, 200.0],
101+
baseline=[0, 0],
102+
write_dir=tmpdir,
103+
wfdb_archive=wfdb_archive,
104+
)
105+
wfdb_archive.close()
106+
107+
assert os.path.exists(archive_path), "Archive was not created"
108+
109+
# Read back from archive
110+
record = rdrecord(archive_path)
111+
112+
assert record.fs == fs
113+
assert record.n_sig == 2
114+
assert record.p_signal.shape == sig.shape
115+
116+
# Add tolerance to account for loss of precision during archive round-trip
117+
np.testing.assert_allclose(record.p_signal, sig, rtol=1e-2, atol=3e-3)
118+
119+
120+
def test_wfdb_archive_round_trip(temp_record):
121+
record_name = temp_record["record_name"]
122+
archive_path = temp_record["archive_path"]
123+
original_signal = temp_record["original_signal"]
124+
fs = temp_record["fs"]
125+
126+
assert os.path.exists(archive_path), "Archive was not created"
127+
128+
record = rdrecord(archive_path)
129+
130+
assert record.fs == fs
131+
assert record.n_sig == 2
132+
assert record.p_signal.shape == original_signal.shape
133+
134+
# Add tolerance to account for loss of precision during archive round-trip
135+
np.testing.assert_allclose(record.p_signal, original_signal, rtol=1e-2,
136+
atol=3e-3)
137+
138+
139+
def test_archive_read_subset_channels(temp_record):
140+
"""
141+
Test reading a subset of channels from an archive.
142+
"""
143+
archive_path = temp_record["archive_path"]
144+
original_signal = temp_record["original_signal"]
145+
146+
record = rdrecord(archive_path, channels=[1])
147+
148+
assert record.n_sig == 1
149+
assert record.p_signal.shape[0] == original_signal.shape[0]
150+
151+
# Add tolerance to account for loss of precision during archive round-trip
152+
np.testing.assert_allclose(record.p_signal[:, 0], original_signal[:, 1],
153+
rtol=1e-2, atol=3e-3)
154+
155+
156+
def test_archive_read_partial_samples(temp_record):
157+
"""
158+
Test reading a sample range from the archive.
159+
"""
160+
archive_path = temp_record["archive_path"]
161+
original_signal = temp_record["original_signal"]
162+
163+
start, stop = 100, 200
164+
record = rdrecord(archive_path, sampfrom=start, sampto=stop)
165+
166+
assert record.p_signal.shape == (stop - start, original_signal.shape[1])
167+
np.testing.assert_allclose(record.p_signal, original_signal[start:stop], rtol=1e-2, atol=1e-3)
168+
169+
170+
def test_archive_missing_file_error(temp_record):
171+
"""
172+
Ensure appropriate error is raised when expected files are missing from the archive.
173+
"""
174+
archive_path = temp_record["archive_path"]
175+
176+
# Remove one file from archive (e.g. the .dat file)
177+
with zipfile.ZipFile(archive_path, "a") as zf:
178+
zf_name = [name for name in zf.namelist() if name.endswith(".dat")][0]
179+
zf.fp = None # Prevent auto-close bug in some zipfile implementations
180+
os.rename(archive_path, archive_path + ".bak")
181+
with zipfile.ZipFile(archive_path + ".bak", "r") as zin, \
182+
zipfile.ZipFile(archive_path, "w") as zout:
183+
for item in zin.infolist():
184+
if not item.filename.endswith(".dat"):
185+
zout.writestr(item, zin.read(item.filename))
186+
os.remove(archive_path + ".bak")
187+
188+
with pytest.raises(FileNotFoundError, match=".*\.dat.*"):
189+
rdrecord(archive_path)

0 commit comments

Comments
 (0)