Skip to content

Commit 7814876

Browse files
committed
Add support for writing .wfdb files.
1 parent c48abd0 commit 7814876

File tree

4 files changed

+88
-36
lines changed

4 files changed

+88
-36
lines changed

wfdb/io/_header.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import os
23
from typing import Any, Dict, List, Optional, Sequence, Tuple
34

45
import numpy as np
@@ -278,7 +279,7 @@ def set_defaults(self):
278279
for f in sfields:
279280
self.set_default(f)
280281

281-
def wrheader(self, write_dir="", expanded=True):
282+
def wrheader(self, write_dir="", expanded=True, wfdb_archive=None):
282283
"""
283284
Write a WFDB header file. The signals are not used. Before
284285
writing:
@@ -325,7 +326,8 @@ def wrheader(self, write_dir="", expanded=True):
325326
self.check_field_cohesion(rec_write_fields, list(sig_write_fields))
326327

327328
# Write the header file using the specified fields
328-
self.wr_header_file(rec_write_fields, sig_write_fields, write_dir)
329+
self.wr_header_file(rec_write_fields, sig_write_fields, write_dir,
330+
wfdb_archive=wfdb_archive)
329331

330332
def get_write_fields(self):
331333
"""
@@ -508,7 +510,8 @@ def check_field_cohesion(self, rec_write_fields, sig_write_fields):
508510
"Each file_name (dat file) specified must have the same byte offset"
509511
)
510512

511-
def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
513+
def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir,
514+
wfdb_archive=None):
512515
"""
513516
Write a header file using the specified fields. Converts Record
514517
attributes into appropriate WFDB format strings.
@@ -522,6 +525,8 @@ def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
522525
being equal to a list of channels to write for each field.
523526
write_dir : str
524527
The directory in which to write the header file.
528+
wfdb_archive : WFDBArchive, optional
529+
If provided, write the header into this archive instead of to disk.
525530
526531
Returns
527532
-------
@@ -583,7 +588,13 @@ def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
583588
comment_lines = ["# " + comment for comment in self.comments]
584589
header_lines += comment_lines
585590

586-
util.lines_to_file(self.record_name + ".hea", write_dir, header_lines)
591+
header_str = "\n".join(header_lines) + "\n"
592+
hea_filename = os.path.basename(self.record_name) + ".hea"
593+
594+
if wfdb_archive:
595+
wfdb_archive.write(hea_filename, header_str.encode("utf-8"))
596+
else:
597+
util.lines_to_file(hea_filename, write_dir, header_lines)
587598

588599

589600
class MultiHeaderMixin(BaseHeaderMixin):
@@ -621,7 +632,7 @@ def set_defaults(self):
621632
for field in self.get_write_fields():
622633
self.set_default(field)
623634

624-
def wrheader(self, write_dir=""):
635+
def wrheader(self, write_dir="", wfdb_archive=None):
625636
"""
626637
Write a multi-segment WFDB header file. The signals or segments are
627638
not used. Before writing:
@@ -655,7 +666,7 @@ def wrheader(self, write_dir=""):
655666
self.check_field_cohesion()
656667

657668
# Write the header file using the specified fields
658-
self.wr_header_file(write_fields, write_dir)
669+
self.wr_header_file(write_fields, write_dir, wfdb_archive=wfdb_archive)
659670

660671
def get_write_fields(self):
661672
"""
@@ -733,7 +744,7 @@ def check_field_cohesion(self):
733744
"The sum of the 'seg_len' fields do not match the 'sig_len' field"
734745
)
735746

736-
def wr_header_file(self, write_fields, write_dir):
747+
def wr_header_file(self, write_fields, write_dir, wfdb_archive=None):
737748
"""
738749
Write a header file using the specified fields.
739750
@@ -744,6 +755,8 @@ def wr_header_file(self, write_fields, write_dir):
744755
and their dependencies.
745756
write_dir : str
746757
The output directory in which the header is written.
758+
wfdb_archive : WFDBArchive, optional
759+
If provided, write the header into this archive instead of to disk.
747760
748761
Returns
749762
-------
@@ -779,7 +792,13 @@ def wr_header_file(self, write_fields, write_dir):
779792
comment_lines = ["# " + comment for comment in self.comments]
780793
header_lines += comment_lines
781794

782-
util.lines_to_file(self.record_name + ".hea", write_dir, header_lines)
795+
header_str = "\n".join(header_lines) + "\n"
796+
hea_filename = os.path.basename(self.record_name) + ".hea"
797+
798+
if wfdb_archive:
799+
wfdb_archive.write(hea_filename, header_str.encode("utf-8"))
800+
else:
801+
util.lines_to_file(hea_filename, write_dir, header_lines)
783802

784803
def get_sig_segments(self, sig_name=None):
785804
"""

wfdb/io/_signal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2574,7 +2574,7 @@ def wr_dat_file(
25742574
# Write the bytes to the file
25752575
if wfdb_archive:
25762576
with io.BytesIO() as f:
2577-
b_write.tofile(f)
2577+
f.write(b_write.tobytes())
25782578
wfdb_archive.write(os.path.basename(file_name), f.getvalue())
25792579
else:
25802580
with open(file_path, "wb") as f:

wfdb/io/archive.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,42 @@ class WFDBArchive:
1010
"""
1111
Helper class for working with WFDB .wfdb ZIP archives.
1212
13+
If used for reading, the archive must already exist.
14+
If used for writing, use mode='w' and call `write(...)` or `create_archive(...)`.
15+
1316
Used only if:
1417
- .wfdb is included in the record_name explicitly, or
1518
- .wfdb is passed directly to the file loading function.
1619
"""
17-
def __init__(self, record_name):
20+
def __init__(self, record_name, mode="r"):
1821
"""
1922
Initialize a WFDBArchive for a given record name (without extension).
2023
24+
Parameters
25+
----------
2126
record_name : str
22-
The base name of the archive, without the .wfdb extension.
27+
The base name of the archive, without the .wfdb extension.
28+
mode : str
29+
'r' for read (default), 'w' for write.
2330
"""
2431
self.record_name = record_name
2532
self.archive_path = f"{record_name}.wfdb"
33+
self.zipfile = None
34+
self.mode = mode
35+
36+
if mode == "r":
37+
if not os.path.exists(self.archive_path):
38+
raise FileNotFoundError(f"Archive not found: {self.archive_path}")
39+
if not zipfile.is_zipfile(self.archive_path):
40+
raise ValueError(f"Invalid WFDB archive: {self.archive_path}")
41+
self.zipfile = zipfile.ZipFile(self.archive_path, mode="r")
2642

27-
if not os.path.exists(self.archive_path):
28-
raise FileNotFoundError(f"Archive not found: {self.archive_path}")
29-
if not zipfile.is_zipfile(self.archive_path):
30-
raise ValueError(f"Invalid WFDB archive: {self.archive_path}")
31-
self.zipfile = zipfile.ZipFile(self.archive_path, mode="r")
43+
elif mode == "w":
44+
# Initialize an empty archive on disk
45+
if not os.path.exists(self.archive_path):
46+
with zipfile.ZipFile(self.archive_path, mode="w"):
47+
pass # Just create the file
48+
self.zipfile = zipfile.ZipFile(self.archive_path, mode="a")
3249

3350
def exists(self, filename):
3451
"""
@@ -65,16 +82,19 @@ def write(self, filename, data):
6582
"""
6683
Write binary data to the archive (replaces if already exists).
6784
"""
68-
# Write to a new temporary archive
85+
if self.zipfile is None:
86+
self.zipfile = zipfile.ZipFile(self.archive_path, mode="w")
87+
self.zipfile.writestr(filename, data)
88+
return
89+
90+
# If already opened in read or append mode, use the replace-then-move trick
6991
tmp_path = self.archive_path + ".tmp"
7092
with zipfile.ZipFile(self.archive_path, mode="r") as zin:
7193
with zipfile.ZipFile(tmp_path, mode="w") as zout:
7294
for item in zin.infolist():
7395
if item.filename != filename:
7496
zout.writestr(item, zin.read(item.filename))
7597
zout.writestr(filename, data)
76-
77-
# Replace the original archive
7898
shutil.move(tmp_path, self.archive_path)
7999
self.zipfile = zipfile.ZipFile(self.archive_path, mode="a")
80100

@@ -94,10 +114,11 @@ def create_archive(self, file_list, output_path=None):
94114
zf.write(file, arcname=os.path.basename(file), compress_type=compress)
95115

96116

97-
def get_archive(record_base_name):
117+
def get_archive(record_base_name, mode="r"):
98118
"""
99119
Get or create a WFDBArchive for the given record base name.
100120
"""
101121
if record_base_name not in _archive_cache:
102-
_archive_cache[record_base_name] = WFDBArchive(record_base_name)
122+
_archive_cache[record_base_name] = WFDBArchive(record_base_name,
123+
mode=mode)
103124
return _archive_cache[record_base_name]

wfdb/io/record.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -935,11 +935,12 @@ def wrsamp(self, expanded=False, write_dir="", wfdb_archive=None):
935935

936936
# Perform field validity and cohesion checks, and write the
937937
# header file.
938-
self.wrheader(write_dir=write_dir, expanded=expanded)
938+
self.wrheader(write_dir=write_dir, expanded=expanded,
939+
wfdb_archive=wfdb_archive)
939940
if self.n_sig > 0:
940941
# Perform signal validity and cohesion checks, and write the
941942
# associated dat files.
942-
self.wr_dats(expanded=expanded, write_dir=write_dir,
943+
self.wr_dats(expanded=expanded, write_dir=write_dir,
943944
wfdb_archive=wfdb_archive)
944945

945946
def _arrange_fields(self, channels, sampfrom, smooth_frames):
@@ -1162,7 +1163,7 @@ def __init__(
11621163
if not seg_len:
11631164
self.seg_len = [segment.sig_len for segment in segments]
11641165

1165-
def wrsamp(self, write_dir=""):
1166+
def wrsamp(self, write_dir="", wfdb_archive=None):
11661167
"""
11671168
Write a multi-segment header, along with headers and dat files
11681169
for all segments, from this object.
@@ -1179,11 +1180,11 @@ def wrsamp(self, write_dir=""):
11791180
"""
11801181
# Perform field validity and cohesion checks, and write the
11811182
# header file.
1182-
self.wrheader(write_dir=write_dir)
1183+
self.wrheader(write_dir=write_dir, wfdb_archive=wfdb_archive)
11831184
# Perform record validity and cohesion checks, and write the
11841185
# associated segments.
11851186
for seg in self.segments:
1186-
seg.wrsamp(write_dir=write_dir)
1187+
seg.wrsamp(write_dir=write_dir, wfdb_archive=wfdb_archive)
11871188

11881189
def _check_segment_cohesion(self):
11891190
"""
@@ -1828,7 +1829,11 @@ def rdheader(record_name, pn_dir=None, rd_segments=False):
18281829
18291830
"""
18301831
dir_name, base_record_name = os.path.split(record_name)
1831-
file_name = f"{base_record_name}.hea"
1832+
1833+
if not base_record_name.endswith(".hea"):
1834+
file_name = f"{base_record_name}.hea"
1835+
else:
1836+
file_name = base_record_name
18321837

18331838
# If this is a cloud path, use posixpath to construct the path and fsspec to open file
18341839
if any(dir_name.startswith(proto) for proto in CLOUD_PROTOCOLS):
@@ -2032,17 +2037,23 @@ def rdrecord(
20322037
channels=[1, 3])
20332038
20342039
"""
2040+
wfdb_archive = None
20352041
is_wfdb_archive = record_name.endswith(".wfdb")
20362042

20372043
if is_wfdb_archive:
20382044
record_base = record_name[:-5] # remove ".wfdb"
2039-
archive = get_archive(record_base)
2045+
wfdb_archive = get_archive(record_base)
20402046
hea_file = os.path.basename(record_base) + ".hea"
20412047

2042-
with archive.open(hea_file, "r") as f:
2043-
record = Record()
2044-
record.wfdb_archive = archive
2045-
record._read_header(f.read())
2048+
import tempfile
2049+
with wfdb_archive.open(hea_file, "r") as f:
2050+
header_str = f.read()
2051+
2052+
with tempfile.NamedTemporaryFile("w+", suffix=".hea", delete=False) as tmpf:
2053+
tmpf.write(header_str)
2054+
tmpf.flush()
2055+
record = rdheader(tmpf.name)
2056+
record.wfdb_archive = wfdb_archive
20462057

20472058
# Set dir_name to the archive base (needed for _rd_segment)
20482059
dir_name = os.path.dirname(record_base)
@@ -2168,6 +2179,7 @@ def rdrecord(
21682179
no_file=no_file,
21692180
sig_data=sig_data,
21702181
return_res=return_res,
2182+
wfdb_archive=wfdb_archive,
21712183
)
21722184

21732185
# Only 1 sample/frame, or frames are smoothed. Return uniform numpy array
@@ -2879,7 +2891,7 @@ def wrsamp(
28792891
base_date=None,
28802892
base_datetime=None,
28812893
write_dir="",
2882-
archive=False,
2894+
wfdb_archive=None,
28832895
):
28842896
"""
28852897
Write a single segment WFDB record, creating a WFDB header file and any
@@ -3067,9 +3079,9 @@ def wrsamp(
30673079
else:
30683080
expanded = False
30693081

3070-
wfdb_archive = None
3071-
if archive:
3072-
wfdb_archive = get_archive(os.path.join(write_dir, record_name))
3082+
if wfdb_archive:
3083+
wfdb_archive = get_archive(os.path.join(write_dir, record_name),
3084+
mode="w")
30733085

30743086
# Write the record files - header and associated dat
30753087
record.wrsamp(write_dir=write_dir, expanded=expanded, wfdb_archive=wfdb_archive)

0 commit comments

Comments
 (0)